diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..5313f27f7e74524ca470ebef2d4c04202f4ca81b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+data_gen/utils/mp_feature_extractors/face_landmarker.task filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/bundle_adjust.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/camera_position_teapot.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/fit_nerf.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/fit_textured_volume.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/implicitron_config.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/nerf_project_logo.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/docs/notes/assets/batch_modes.gif filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a269d9ca55511e478e976768a1302f2121c17868
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,199 @@
+# big files
+data_util/face_tracking/3DMM/01_MorphableModel.mat
+data_util/face_tracking/3DMM/3DMM_info.npy
+
+!/deep_3drecon/BFM/.gitkeep
+deep_3drecon/BFM/Exp_Pca.bin
+deep_3drecon/BFM/01_MorphableModel.mat
+deep_3drecon/BFM/BFM_model_front.mat
+deep_3drecon/network/FaceReconModel.pb
+deep_3drecon/checkpoints/*
+
+.vscode
+### Project ignore
+/checkpoints/*
+!/checkpoints/.gitkeep
+/data/*
+!/data/.gitkeep
+infer_out
+rsync
+.idea
+.DS_Store
+bak
+tmp
+*.tar.gz
+mos
+nbs
+/configs_usr/*
+!/configs_usr/.gitkeep
+/egs_usr/*
+!/egs_usr/.gitkeep
+/rnnoise
+#/usr/*
+#!/usr/.gitkeep
+scripts_usr
+
+# Created by .ignore support plugin (hsz.mobi)
+### Python template
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+data_util/deepspeech_features/deepspeech-0.9.2-models.pbmm
+deep_3drecon/mesh_renderer/bazel-bin
+deep_3drecon/mesh_renderer/bazel-mesh_renderer
+deep_3drecon/mesh_renderer/bazel-out
+deep_3drecon/mesh_renderer/bazel-testlogs
+
+.nfs*
+infer_outs/*
+
+*.pth
+venv_113/*
+*.pt
+experiments/trials
+flame_3drecon/*
+
+temp/
+/kill.sh
+/datasets
+data_util/imagenet_classes.txt
+process_data_May.sh
+/env_prepare_reproduce.md
+/my_debug.py
+
+utils/metrics/shape_predictor_68_face_landmarks.dat
+*.mp4
+_torchshow/
+*.png
+*.jpg
+
+*.mrc
+
+deep_3drecon/BFM/BFM_exp_idx.mat
+deep_3drecon/BFM/BFM_front_idx.mat
+deep_3drecon/BFM/facemodel_info.mat
+deep_3drecon/BFM/index_mp468_from_mesh35709.npy
+deep_3drecon/BFM/mediapipe_in_bfm53201.npy
+deep_3drecon/BFM/std_exp.txt
+!data/raw/examples/*
\ No newline at end of file
diff --git a/README-zh.md b/README-zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..cc29e185268d89c8d4e7ca3b58e063e12a8c4533
--- /dev/null
+++ b/README-zh.md
@@ -0,0 +1,137 @@
+# Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis | ICLR 2024 Spotlight
+[](https://arxiv.org/abs/2401.08503)| [](https://github.com/yerfor/Real3DPortrait) | [English Readme](./README.md)
+
+这个仓库是Real3D-Portrait的官方PyTorch实现, 用于实现单参考图(one-shot)、高视频真实度(video reality)的虚拟人视频合成。您可以访问我们的[项目页面](https://real3dportrait.github.io/)以观看Demo视频, 阅读我们的[论文](https://arxiv.org/pdf/2401.08503.pdf)以了解技术细节。
+
+
+
+
+
+
+
+# 快速上手!
+## 安装环境
+请参照[环境配置文档](docs/prepare_env/install_guide-zh.md),配置Conda环境`real3dportrait`
+## 下载预训练与第三方模型
+### 3DMM BFM模型
+下载3DMM BFM模型:[Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) 或 [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5 ) 提取码: m9q5
+
+
+下载完成后,放置全部的文件到`deep_3drecon/BFM`里,文件结构如下:
+```
+deep_3drecon/BFM/
+├── 01_MorphableModel.mat
+├── BFM_exp_idx.mat
+├── BFM_front_idx.mat
+├── BFM_model_front.mat
+├── Exp_Pca.bin
+├── facemodel_info.mat
+├── index_mp468_from_mesh35709.npy
+├── mediapipe_in_bfm53201.npy
+└── std_exp.txt
+```
+
+### 预训练模型
+下载预训练的Real3D-Portrait:[Google Drive](https://drive.google.com/drive/folders/1MAveJf7RvJ-Opg1f5qhLdoRoC_Gc6nD9?usp=sharing) 或 [BaiduYun Disk](https://pan.baidu.com/s/1Mjmbn0UtA1Zm9owZ7zWNgQ?pwd=6x4f ) 提取码: 6x4f
+
+下载完成后,放置全部的文件到`checkpoints`里并解压,文件结构如下:
+```
+checkpoints/
+├── 240126_real3dportrait_orig
+│ ├── audio2secc_vae
+│ │ ├── config.yaml
+│ │ └── model_ckpt_steps_400000.ckpt
+│ └── secc2plane_torso_orig
+│ ├── config.yaml
+│ └── model_ckpt_steps_100000.ckpt
+└── pretrained_ckpts
+ └── mit_b0.pth
+```
+
+## 推理测试
+我们目前提供了**命令行(CLI)**与**Gradio WebUI**推理方式,并将在未来提供Google Colab方式。我们同时支持音频驱动(Audio-Driven)与视频驱动(Video-Driven):
+
+- 音频驱动场景下,需要至少提供`source image`与`driving audio`
+- 视频驱动场景下,需要至少提供`source image`与`driving expression video`
+
+### Gradio WebUI推理
+启动Gradio WebUI,按照提示上传素材,点击`Generate`按钮即可推理:
+```bash
+python inference/app_real3dportrait.py
+```
+
+### 命令行推理
+首先,切换至项目根目录并启用Conda环境:
+```bash
+cd
+conda activate real3dportrait
+export PYTHON_PATH=./
+```
+音频驱动场景下,需要至少提供source image与driving audio,推理指令:
+```bash
+python inference/real3d_infer.py \
+--src_img \
+--drv_aud \
+--drv_pose \
+--bg_img \
+--out_name
+```
+视频驱动场景下,需要至少提供source image与driving expression video(作为drv_aud参数),推理指令:
+```bash
+python inference/real3d_infer.py \
+--src_img \
+--drv_aud \
+--drv_pose \
+--bg_img \
+--out_name
+```
+一些可选参数注释:
+- `--drv_pose` 指定时提供了运动pose信息,不指定则为静态运动
+- `--bg_img` 指定时提供了背景信息,不指定则为source image提取的背景
+- `--mouth_amp` 嘴部张幅参数,值越大张幅越大
+- `--map_to_init_pose` 值为`True`时,首帧的pose将被映射到source pose,后续帧也作相同变换
+- `--temperature` 代表audio2motion的采样温度,值越大结果越多样,但同时精确度越低
+- `--out_name` 不指定时,结果将保存在`infer_out/tmp/`中
+- `--out_mode` 值为`final`时,只输出说话人视频;值为`concat_debug`时,同时输出一些可视化的中间结果
+
+指令示例:
+```bash
+python inference/real3d_infer.py \
+--src_img data/raw/examples/Macron.png \
+--drv_aud data/raw/examples/Obama_5s.wav \
+--drv_pose data/raw/examples/May_5s.mp4 \
+--bg_img data/raw/examples/bg.png \
+--out_name output.mp4 \
+--out_mode concat_debug
+```
+
+## ToDo
+- [x] **Release Pre-trained weights of Real3D-Portrait.**
+- [x] **Release Inference Code of Real3D-Portrait.**
+- [x] **Release Gradio Demo of Real3D-Portrait..**
+- [ ] **Release Google Colab of Real3D-Portrait..**
+- [ ] **Release Training Code of Real3D-Portrait.**
+
+# 引用我们
+如果这个仓库对你有帮助,请考虑引用我们的工作:
+```
+@article{ye2024real3d,
+ title={Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
+ author={Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
+ journal={arXiv preprint arXiv:2401.08503},
+ year={2024}
+}
+@article{ye2023geneface++,
+ title={GeneFace++: Generalized and Stable Real-Time Audio-Driven 3D Talking Face Generation},
+ author={Ye, Zhenhui and He, Jinzheng and Jiang, Ziyue and Huang, Rongjie and Huang, Jiawei and Liu, Jinglin and Ren, Yi and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
+ journal={arXiv preprint arXiv:2305.00787},
+ year={2023}
+}
+@article{ye2023geneface,
+ title={GeneFace: Generalized and High-Fidelity Audio-Driven 3D Talking Face Synthesis},
+ author={Ye, Zhenhui and Jiang, Ziyue and Ren, Yi and Liu, Jinglin and He, Jinzheng and Zhao, Zhou},
+ journal={arXiv preprint arXiv:2301.13430},
+ year={2023}
+}
+```
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3b25bd20d9032bc7fccc9ff27c93d52053cbbbd6
--- /dev/null
+++ b/README.md
@@ -0,0 +1,137 @@
+# Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis | ICLR 2024 Spotlight
+[](https://arxiv.org/abs/2401.08503)| [](https://github.com/yerfor/Real3DPortrait) | [中文文档](./README-zh.md)
+
+This is the official repo of Real3D-Portrait with Pytorch implementation, for one-shot and high video reality talking portrait synthesis. You can visit our [Demo Page](https://real3dportrait.github.io/) for watching demo videos, and read our [Paper](https://arxiv.org/pdf/2401.08503.pdf) for technical details.
+
+
+
+
+
+
+
+# Quick Start!
+## Environment Installation
+Please refer to [Installation Guide](docs/prepare_env/install_guide.md), prepare a Conda environment `real3dportrait`.
+## Download Pre-trained & Third-Party Models
+### 3DMM BFM Model
+Download 3DMM BFM Model from [Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5 ) with Password m9q5.
+
+
+Put all the files in `deep_3drecon/BFM`, the file structure will be like this:
+```
+deep_3drecon/BFM/
+├── 01_MorphableModel.mat
+├── BFM_exp_idx.mat
+├── BFM_front_idx.mat
+├── BFM_model_front.mat
+├── Exp_Pca.bin
+├── facemodel_info.mat
+├── index_mp468_from_mesh35709.npy
+├── mediapipe_in_bfm53201.npy
+└── std_exp.txt
+```
+
+### Pre-trained Real3D-Portrait
+Download Pre-trained Real3D-Portrait:[Google Drive](https://drive.google.com/drive/folders/1MAveJf7RvJ-Opg1f5qhLdoRoC_Gc6nD9?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1Mjmbn0UtA1Zm9owZ7zWNgQ?pwd=6x4f ) with Password 6x4f
+
+Put the zip files in `checkpoints` and unzip them, the file structure will be like this:
+```
+checkpoints/
+├── 240126_real3dportrait_orig
+│ ├── audio2secc_vae
+│ │ ├── config.yaml
+│ │ └── model_ckpt_steps_400000.ckpt
+│ └── secc2plane_torso_orig
+│ ├── config.yaml
+│ └── model_ckpt_steps_100000.ckpt
+└── pretrained_ckpts
+ └── mit_b0.pth
+```
+
+## Inference
+Currently, we provide **CLI** and **Gradio WebUI** for inference, and Google Colab will be provided in the future. We support both Audio-Driven and Video-Driven methods:
+
+- For audio-driven, at least prepare `source image` and `driving audio`
+- For video-driven, at least prepare `source image` and `driving expression video`
+
+### Gradio WebUI
+Run Gradio WebUI demo, upload resouces in webpage,click `Generate` button to inference:
+```bash
+python inference/app_real3dportrait.py
+```
+
+### CLI Inference
+Firstly, switch to project folder and activate conda environment:
+```bash
+cd
+conda activate real3dportrait
+export PYTHON_PATH=./
+```
+For audio-driven, provide source image and driving audio:
+```bash
+python inference/real3d_infer.py \
+--src_img \
+--drv_aud \
+--drv_pose \
+--bg_img \
+--out_name
+```
+For video-driven, provide source image and driving expression video(as `--drv_aud` parameter):
+```bash
+python inference/real3d_infer.py \
+--src_img \
+--drv_aud \
+--drv_pose \
+--bg_img \
+--out_name
+```
+Some optional parameters:
+- `--drv_pose` provide motion pose information, default to be static poses
+- `--bg_img` provide background information, default to be image extracted from source
+- `--mouth_amp` mouth amplitude, higher value leads to wider mouth
+- `--map_to_init_pose` when set to `True`, the initial pose will be mapped to source pose, and other poses will be equally transformed
+- `--temperature` stands for the sampling temperature of audio2motion, higher for more diverse results at the expense of lower accuracy
+- `--out_name` When not assigned, the results will be stored at `infer_out/tmp/`.
+- `--out_mode` When `final`, only outputs the final result; when `concat_debug`, also outputs visualization of several intermediate process.
+
+Commandline example:
+```bash
+python inference/real3d_infer.py \
+--src_img data/raw/examples/Macron.png \
+--drv_aud data/raw/examples/Obama_5s.wav \
+--drv_pose data/raw/examples/May_5s.mp4 \
+--bg_img data/raw/examples/bg.png \
+--out_name output.mp4 \
+--out_mode concat_debug
+```
+
+# ToDo
+- [x] **Release Pre-trained weights of Real3D-Portrait.**
+- [x] **Release Inference Code of Real3D-Portrait.**
+- [x] **Release Gradio Demo of Real3D-Portrait..**
+- [ ] **Release Google Colab of Real3D-Portrait..**
+- [ ] **Release Training Code of Real3D-Portrait.**
+
+# Citation
+If you found this repo helpful to your work, please consider cite us:
+```
+@article{ye2024real3d,
+ title={Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
+ author={Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
+ journal={arXiv preprint arXiv:2401.08503},
+ year={2024}
+}
+@article{ye2023geneface++,
+ title={GeneFace++: Generalized and Stable Real-Time Audio-Driven 3D Talking Face Generation},
+ author={Ye, Zhenhui and He, Jinzheng and Jiang, Ziyue and Huang, Rongjie and Huang, Jiawei and Liu, Jinglin and Ren, Yi and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
+ journal={arXiv preprint arXiv:2305.00787},
+ year={2023}
+}
+@article{ye2023geneface,
+ title={GeneFace: Generalized and High-Fidelity Audio-Driven 3D Talking Face Synthesis},
+ author={Ye, Zhenhui and Jiang, Ziyue and Ren, Yi and Liu, Jinglin and He, Jinzheng and Zhao, Zhou},
+ journal={arXiv preprint arXiv:2301.13430},
+ year={2023}
+}
+```
\ No newline at end of file
diff --git a/checkpoints/.gitkeep b/checkpoints/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data_gen/eg3d/convert_to_eg3d_convention.py b/data_gen/eg3d/convert_to_eg3d_convention.py
new file mode 100644
index 0000000000000000000000000000000000000000..45d4e4b11dc69aa82ac0194c0df1b30d0ff020a7
--- /dev/null
+++ b/data_gen/eg3d/convert_to_eg3d_convention.py
@@ -0,0 +1,146 @@
+import numpy as np
+import torch
+import copy
+from utils.commons.tensor_utils import convert_to_tensor, convert_to_np
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+
+
+def _fix_intrinsics(intrinsics):
+ """
+ intrinsics: [3,3], not batch-wise
+ """
+ # unnormalized normalized
+
+ # [[ f_x, s=0, x_0] [[ f_x/size_x, s=0, x_0/size_x=0.5]
+ # [ 0, f_y, y_0] -> [ 0, f_y/size_y, y_0/size_y=0.5]
+ # [ 0, 0, 1 ]] [ 0, 0, 1 ]]
+ intrinsics = np.array(intrinsics).copy()
+ assert intrinsics.shape == (3, 3), intrinsics
+ intrinsics[0,0] = 2985.29/700
+ intrinsics[1,1] = 2985.29/700
+ intrinsics[0,2] = 1/2
+ intrinsics[1,2] = 1/2
+ assert intrinsics[0,1] == 0
+ assert intrinsics[2,2] == 1
+ assert intrinsics[1,0] == 0
+ assert intrinsics[2,0] == 0
+ assert intrinsics[2,1] == 0
+ return intrinsics
+
+# Used in original submission
+def _fix_pose_orig(pose):
+ """
+ pose: [4,4], not batch-wise
+ """
+ pose = np.array(pose).copy()
+ location = pose[:3, 3]
+ radius = np.linalg.norm(location)
+ pose[:3, 3] = pose[:3, 3]/radius * 2.7
+ return pose
+
+
+def get_eg3d_convention_camera_pose_intrinsic(item):
+ """
+ item: a dict during binarize
+
+ """
+ if item['euler'].ndim == 1:
+ angle = convert_to_tensor(copy.copy(item['euler']))
+ trans = copy.deepcopy(item['trans'])
+
+ # handle the difference of euler axis between eg3d and ours
+ # see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
+ # angle += torch.tensor([0, 3.1415926535, 3.1415926535], device=angle.device)
+ R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
+ trans[2] += -10
+ c = -np.dot(R, trans)
+ pose = np.eye(4)
+ pose[:3,:3] = R
+ c *= 0.27 # normalize camera radius
+ c[1] += 0.006 # additional offset used in submission
+ c[2] += 0.161 # additional offset used in submission
+ pose[0,3] = c[0]
+ pose[1,3] = c[1]
+ pose[2,3] = c[2]
+
+ focal = 2985.29 # = 1015*1024/224*(300/466.285),
+ # todo: 如果修改了fit 3dmm阶段的camera intrinsic,这里也要跟着改
+ pp = 512#112
+ w = 1024#224
+ h = 1024#224
+
+ K = np.eye(3)
+ K[0][0] = focal
+ K[1][1] = focal
+ K[0][2] = w/2.0
+ K[1][2] = h/2.0
+ convention_K = _fix_intrinsics(K)
+
+ Rot = np.eye(3)
+ Rot[0, 0] = 1
+ Rot[1, 1] = -1
+ Rot[2, 2] = -1
+ pose[:3, :3] = np.dot(pose[:3, :3], Rot) # permute axes
+ convention_pose = _fix_pose_orig(pose)
+
+ item['c2w'] = pose
+ item['convention_c2w'] = convention_pose
+ item['intrinsics'] = convention_K
+ return item
+ else:
+ num_samples = len(item['euler'])
+ eulers_all = convert_to_tensor(copy.deepcopy(item['euler'])) # [B, 3]
+ trans_all = copy.deepcopy(item['trans']) # [B, 3]
+
+ # handle the difference of euler axis between eg3d and ours
+ # see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
+ # eulers_all += torch.tensor([0, 3.1415926535, 3.1415926535], device=eulers_all.device).unsqueeze(0).repeat([eulers_all.shape[0],1])
+
+ intrinsics = []
+ poses = []
+ convention_poses = []
+ for i in range(num_samples):
+ angle = eulers_all[i]
+ trans = trans_all[i]
+ R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
+ trans[2] += -10
+ c = -np.dot(R, trans)
+ pose = np.eye(4)
+ pose[:3,:3] = R
+ c *= 0.27 # normalize camera radius
+ c[1] += 0.006 # additional offset used in submission
+ c[2] += 0.161 # additional offset used in submission
+ pose[0,3] = c[0]
+ pose[1,3] = c[1]
+ pose[2,3] = c[2]
+
+ focal = 2985.29 # = 1015*1024/224*(300/466.285),
+ # todo: 如果修改了fit 3dmm阶段的camera intrinsic,这里也要跟着改
+ pp = 512#112
+ w = 1024#224
+ h = 1024#224
+
+ K = np.eye(3)
+ K[0][0] = focal
+ K[1][1] = focal
+ K[0][2] = w/2.0
+ K[1][2] = h/2.0
+ convention_K = _fix_intrinsics(K)
+ intrinsics.append(convention_K)
+
+ Rot = np.eye(3)
+ Rot[0, 0] = 1
+ Rot[1, 1] = -1
+ Rot[2, 2] = -1
+ pose[:3, :3] = np.dot(pose[:3, :3], Rot)
+ convention_pose = _fix_pose_orig(pose)
+ convention_poses.append(convention_pose)
+ poses.append(pose)
+
+ intrinsics = np.stack(intrinsics) # [B, 3, 3]
+ poses = np.stack(poses) # [B, 4, 4]
+ convention_poses = np.stack(convention_poses) # [B, 4, 4]
+ item['intrinsics'] = intrinsics
+ item['c2w'] = poses
+ item['convention_c2w'] = convention_poses
+ return item
diff --git a/data_gen/runs/binarizer_nerf.py b/data_gen/runs/binarizer_nerf.py
new file mode 100644
index 0000000000000000000000000000000000000000..623cd17f6b52c9a981721a8ca14e24af1edfe202
--- /dev/null
+++ b/data_gen/runs/binarizer_nerf.py
@@ -0,0 +1,335 @@
+import os
+import numpy as np
+import math
+import json
+import imageio
+import torch
+import tqdm
+import cv2
+
+from data_util.face3d_helper import Face3DHelper
+from utils.commons.euler2rot import euler_trans_2_c2w, c2w_to_euler_trans
+from data_gen.utils.process_video.euler2quaterion import euler2quaterion, quaterion2euler
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+
+
+def euler2rot(euler_angle):
+ batch_size = euler_angle.shape[0]
+ theta = euler_angle[:, 0].reshape(-1, 1, 1)
+ phi = euler_angle[:, 1].reshape(-1, 1, 1)
+ psi = euler_angle[:, 2].reshape(-1, 1, 1)
+ one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
+ zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
+ rot_x = torch.cat((
+ torch.cat((one, zero, zero), 1),
+ torch.cat((zero, theta.cos(), theta.sin()), 1),
+ torch.cat((zero, -theta.sin(), theta.cos()), 1),
+ ), 2)
+ rot_y = torch.cat((
+ torch.cat((phi.cos(), zero, -phi.sin()), 1),
+ torch.cat((zero, one, zero), 1),
+ torch.cat((phi.sin(), zero, phi.cos()), 1),
+ ), 2)
+ rot_z = torch.cat((
+ torch.cat((psi.cos(), -psi.sin(), zero), 1),
+ torch.cat((psi.sin(), psi.cos(), zero), 1),
+ torch.cat((zero, zero, one), 1)
+ ), 2)
+ return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
+
+
+def rot2euler(rot_mat):
+ batch_size = len(rot_mat)
+ # we assert that y in in [-0.5pi, 0.5pi]
+ cos_y = torch.sqrt(rot_mat[:, 1, 2] * rot_mat[:, 1, 2] + rot_mat[:, 2, 2] * rot_mat[:, 2, 2])
+ theta_x = torch.atan2(-rot_mat[:, 1, 2], rot_mat[:, 2, 2])
+ theta_y = torch.atan2(rot_mat[:, 2, 0], cos_y)
+ theta_z = torch.atan2(rot_mat[:, 0, 1], rot_mat[:, 0, 0])
+ euler_angles = torch.zeros([batch_size, 3])
+ euler_angles[:, 0] = theta_x
+ euler_angles[:, 1] = theta_y
+ euler_angles[:, 2] = theta_z
+ return euler_angles
+
+index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
+ 33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
+
+def plot_lm2d(lm2d):
+ WH = 512
+ img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
+
+ for i in range(len(lm2d)):
+ x, y = lm2d[i]
+ color = (255,0,0)
+ img = cv2.circle(img, center=(int(x),int(y)), radius=3, color=color, thickness=-1)
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ for i in range(len(lm2d)):
+ x, y = lm2d[i]
+ img = cv2.putText(img, f"{i}", org=(int(x),int(y)), fontFace=font, fontScale=0.3, color=(255,0,0))
+ return img
+
+def get_face_rect(lms, h, w):
+ """
+ lms: [68, 2]
+ h, w: int
+ return: [4,]
+ """
+ assert len(lms) == 68
+ # min_x, max_x = np.min(lms, 0)[0], np.max(lms, 0)[0]
+ min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
+ cx = int((min_x+max_x)/2.0)
+ cy = int(lms[27, 1])
+ h_w = int((max_x-cx)*1.5)
+ h_h = int((lms[8, 1]-cy)*1.15)
+ rect_x = cx - h_w
+ rect_y = cy - h_h
+ if rect_x < 0:
+ rect_x = 0
+ if rect_y < 0:
+ rect_y = 0
+ rect_w = min(w-1-rect_x, 2*h_w)
+ rect_h = min(h-1-rect_y, 2*h_h)
+ # rect = np.array((rect_x, rect_y, rect_w, rect_h), dtype=np.int32)
+ # rect = [rect_x, rect_y, rect_w, rect_h]
+ rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
+ return rect # this x is width, y is height
+
+def get_lip_rect(lms, h, w):
+ """
+ lms: [68, 2]
+ h, w: int
+ return: [4,]
+ """
+ # this x is width, y is height
+ # for lms, lms[:, 0] is width, lms[:, 1] is height
+ assert len(lms) == 68
+ lips = slice(48, 60)
+ lms = lms[lips]
+ min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
+ min_y, max_y = np.min(lms[:, 1]), np.max(lms[:, 1])
+ cx = int((min_x+max_x)/2.0)
+ cy = int((min_y+max_y)/2.0)
+ h_w = int((max_x-cx)*1.2)
+ h_h = int((max_y-cy)*1.2)
+
+ h_w = max(h_w, h_h)
+ h_h = h_w
+
+ rect_x = cx - h_w
+ rect_y = cy - h_h
+ rect_w = 2*h_w
+ rect_h = 2*h_h
+ if rect_x < 0:
+ rect_x = 0
+ if rect_y < 0:
+ rect_y = 0
+
+ if rect_x + rect_w > w:
+ rect_x = w - rect_w
+ if rect_y + rect_h > h:
+ rect_y = h - rect_h
+
+ rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
+ return rect # this x is width, y is height
+
+
+# def get_lip_rect(lms, h, w):
+# """
+# lms: [68, 2]
+# h, w: int
+# return: [4,]
+# """
+# assert len(lms) == 68
+# lips = slice(48, 60)
+# # this x is width, y is height
+# xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max())
+# ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max())
+# # padding to H == W
+# cx = (xmin + xmax) // 2
+# cy = (ymin + ymax) // 2
+# l = max(xmax - xmin, ymax - ymin) // 2
+# xmin = max(0, cx - l)
+# xmax = min(h, cx + l)
+# ymin = max(0, cy - l)
+# ymax = min(w, cy + l)
+# lip_rect = [xmin, xmax, ymin, ymax]
+# return lip_rect
+
+def get_win_conds(conds, idx, smo_win_size=8, pad_option='zero'):
+ """
+ conds: [b, t=16, h=29]
+ idx: long, time index of the selected frame
+ """
+ idx = max(0, idx)
+ idx = min(idx, conds.shape[0]-1)
+ smo_half_win_size = smo_win_size//2
+ left_i = idx - smo_half_win_size
+ right_i = idx + (smo_win_size - smo_half_win_size)
+ pad_left, pad_right = 0, 0
+ if left_i < 0:
+ pad_left = -left_i
+ left_i = 0
+ if right_i > conds.shape[0]:
+ pad_right = right_i - conds.shape[0]
+ right_i = conds.shape[0]
+ conds_win = conds[left_i:right_i]
+ if pad_left > 0:
+ if pad_option == 'zero':
+ conds_win = np.concatenate([np.zeros_like(conds_win)[:pad_left], conds_win], axis=0)
+ elif pad_option == 'edge':
+ edge_value = conds[0][np.newaxis, ...]
+ conds_win = np.concatenate([edge_value] * pad_left + [conds_win], axis=0)
+ else:
+ raise NotImplementedError
+ if pad_right > 0:
+ if pad_option == 'zero':
+ conds_win = np.concatenate([conds_win, np.zeros_like(conds_win)[:pad_right]], axis=0)
+ elif pad_option == 'edge':
+ edge_value = conds[-1][np.newaxis, ...]
+ conds_win = np.concatenate([conds_win] + [edge_value] * pad_right , axis=0)
+ else:
+ raise NotImplementedError
+ assert conds_win.shape[0] == smo_win_size
+ return conds_win
+
+
+def load_processed_data(processed_dir):
+ # load necessary files
+ background_img_name = os.path.join(processed_dir, "bg.jpg")
+ assert os.path.exists(background_img_name)
+ head_img_dir = os.path.join(processed_dir, "head_imgs")
+ torso_img_dir = os.path.join(processed_dir, "inpaint_torso_imgs")
+ gt_img_dir = os.path.join(processed_dir, "gt_imgs")
+
+ hubert_npy_name = os.path.join(processed_dir, "aud_hubert.npy")
+ mel_f0_npy_name = os.path.join(processed_dir, "aud_mel_f0.npy")
+ coeff_npy_name = os.path.join(processed_dir, "coeff_fit_mp.npy")
+ lm2d_npy_name = os.path.join(processed_dir, "lms_2d.npy")
+
+ ret_dict = {}
+
+ ret_dict['bg_img'] = imageio.imread(background_img_name)
+ ret_dict['H'], ret_dict['W'] = ret_dict['bg_img'].shape[:2]
+ ret_dict['focal'], ret_dict['cx'], ret_dict['cy'] = face_model.focal, face_model.center, face_model.center
+
+ print("loading lm2d coeff ...")
+ lm2d_arr = np.load(lm2d_npy_name)
+ face_rect_lst = []
+ lip_rect_lst = []
+ for lm2d in lm2d_arr:
+ if len(lm2d) in [468, 478]:
+ lm2d = lm2d[index_lm68_from_lm468]
+ face_rect = get_face_rect(lm2d, ret_dict['H'], ret_dict['W'])
+ lip_rect = get_lip_rect(lm2d, ret_dict['H'], ret_dict['W'])
+ face_rect_lst.append(face_rect)
+ lip_rect_lst.append(lip_rect)
+ face_rects = np.stack(face_rect_lst, axis=0) # [T, 4]
+
+ print("loading fitted 3dmm coeff ...")
+ coeff_dict = np.load(coeff_npy_name, allow_pickle=True).tolist()
+ identity_arr = coeff_dict['id']
+ exp_arr = coeff_dict['exp']
+ ret_dict['id'] = identity_arr
+ ret_dict['exp'] = exp_arr
+ euler_arr = ret_dict['euler'] = coeff_dict['euler']
+ trans_arr = ret_dict['trans'] = coeff_dict['trans']
+ print("calculating lm3d ...")
+ idexp_lm3d_arr = face3d_helper.reconstruct_idexp_lm3d(torch.from_numpy(identity_arr), torch.from_numpy(exp_arr)).cpu().numpy().reshape([-1, 68*3])
+ len_motion = len(idexp_lm3d_arr)
+ video_idexp_lm3d_mean = idexp_lm3d_arr.mean(axis=0)
+ video_idexp_lm3d_std = idexp_lm3d_arr.std(axis=0)
+ ret_dict['idexp_lm3d'] = idexp_lm3d_arr
+ ret_dict['idexp_lm3d_mean'] = video_idexp_lm3d_mean
+ ret_dict['idexp_lm3d_std'] = video_idexp_lm3d_std
+
+ # now we convert the euler_trans from deep3d convention to adnerf convention
+ eulers = torch.FloatTensor(euler_arr)
+ trans = torch.FloatTensor(trans_arr)
+ rots = face_model.compute_rotation(eulers) # rotation matrix is a better intermediate for convention-transplan than euler
+
+ # handle the camera pose to geneface's convention
+ trans[:, 2] = 10 - trans[:, 2] # 抵消fit阶段的to_camera操作,即trans[...,2] = 10 - trans[...,2]
+ rots = rots.permute(0, 2, 1)
+ trans[:, 2] = - trans[:,2] # 因为intrinsic proj不同
+ # below is the NeRF camera preprocessing strategy, see `save_transforms` in data_util/process.py
+ trans = trans / 10.0
+ rots_inv = rots.permute(0, 2, 1)
+ trans_inv = - torch.bmm(rots_inv, trans.unsqueeze(2))
+
+ pose = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat([len_motion, 1, 1]) # [T, 4, 4]
+ pose[:, :3, :3] = rots_inv
+ pose[:, :3, 3] = trans_inv[:, :, 0]
+ c2w_transform_matrices = pose.numpy()
+
+ # process the audio features used for postnet training
+ print("loading hubert ...")
+ hubert_features = np.load(hubert_npy_name)
+ print("loading Mel and F0 ...")
+ mel_f0_features = np.load(mel_f0_npy_name, allow_pickle=True).tolist()
+
+ ret_dict['hubert'] = hubert_features
+ ret_dict['mel'] = mel_f0_features['mel']
+ ret_dict['f0'] = mel_f0_features['f0']
+
+ # obtaining train samples
+ frame_indices = list(range(len_motion))
+ num_train = len_motion // 11 * 10
+ train_indices = frame_indices[:num_train]
+ val_indices = frame_indices[num_train:]
+
+ for split in ['train', 'val']:
+ if split == 'train':
+ indices = train_indices
+ samples = []
+ ret_dict['train_samples'] = samples
+ elif split == 'val':
+ indices = val_indices
+ samples = []
+ ret_dict['val_samples'] = samples
+
+ for idx in indices:
+ sample = {}
+ sample['idx'] = idx
+ sample['head_img_fname'] = os.path.join(head_img_dir,f"{idx:08d}.png")
+ sample['torso_img_fname'] = os.path.join(torso_img_dir,f"{idx:08d}.png")
+ sample['gt_img_fname'] = os.path.join(gt_img_dir,f"{idx:08d}.jpg")
+ # assert os.path.exists(sample['head_img_fname']) and os.path.exists(sample['torso_img_fname']) and os.path.exists(sample['gt_img_fname'])
+ sample['face_rect'] = face_rects[idx]
+ sample['lip_rect'] = lip_rect_lst[idx]
+ sample['c2w'] = c2w_transform_matrices[idx]
+ samples.append(sample)
+ return ret_dict
+
+
+class Binarizer:
+ def __init__(self):
+ self.data_dir = 'data/'
+
+ def parse(self, video_id):
+ processed_dir = os.path.join(self.data_dir, 'processed/videos', video_id)
+ binary_dir = os.path.join(self.data_dir, 'binary/videos', video_id)
+ out_fname = os.path.join(binary_dir, "trainval_dataset.npy")
+ os.makedirs(binary_dir, exist_ok=True)
+ ret = load_processed_data(processed_dir)
+ mel_name = os.path.join(processed_dir, 'aud_mel_f0.npy')
+ mel_f0_dict = np.load(mel_name, allow_pickle=True).tolist()
+ ret.update(mel_f0_dict)
+ np.save(out_fname, ret, allow_pickle=True)
+
+
+
+if __name__ == '__main__':
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument('--video_id', type=str, default='May', help='')
+ args = parser.parse_args()
+ ### Process Single Long Audio for NeRF dataset
+ video_id = args.video_id
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
+ camera_distance=10, focal=1015)
+ face_model.to("cpu")
+ face3d_helper = Face3DHelper()
+
+ binarizer = Binarizer()
+ binarizer.parse(video_id)
+ print(f"Binarization for {video_id} Done!")
diff --git a/data_gen/runs/nerf/process_guide.md b/data_gen/runs/nerf/process_guide.md
new file mode 100644
index 0000000000000000000000000000000000000000..2312d416fcd50cee8656803fe2fdba141e62e86f
--- /dev/null
+++ b/data_gen/runs/nerf/process_guide.md
@@ -0,0 +1,49 @@
+# 温馨提示:第一次执行可以先一步步跑完下面的命令行,把环境跑通后,之后可以直接运行同目录的run.sh,一键完成下面的所有步骤。
+
+# Step0. 将视频Crop到512x512分辨率,25FPS,确保每一帧都有目标人脸
+```
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 data/raw/videos/${VIDEO_ID}_512.mp4
+mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
+mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
+```
+# step1: 提取音频特征, 如mel, f0, hubuert, esperanto
+```
+export CUDA_VISIBLE_DEVICES=0
+export VIDEO_ID=May
+mkdir -p data/processed/videos/${VIDEO_ID}
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 data/processed/videos/${VIDEO_ID}/aud.wav
+python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
+python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
+```
+
+# Step2. 提取图片
+```
+export VIDEO_ID=May
+export CUDA_VISIBLE_DEVICES=0
+mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
+python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
+```
+
+# Step3. 提取lm2d_mediapipe
+### 提取2D landmark用于之后Fit 3DMM
+### num_workers是本机上的CPU worker数量;total_process是使用的机器数;process_id是本机的编号
+
+```
+export VIDEO_ID=May
+python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
+```
+
+# Step3. fit 3dmm
+```
+export VIDEO_ID=May
+export CUDA_VISIBLE_DEVICES=0
+python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
+```
+
+# Step4. Binarize
+```
+export VIDEO_ID=May
+python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
+```
+可以看到在`data/binary/videos/Mayssss`目录下得到了数据集。
\ No newline at end of file
diff --git a/data_gen/runs/nerf/run.sh b/data_gen/runs/nerf/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f028ad9e061c925e51946ff83c27e99c35cbb15c
--- /dev/null
+++ b/data_gen/runs/nerf/run.sh
@@ -0,0 +1,51 @@
+# usage: CUDA_VISIBLE_DEVICES=0 bash data_gen/runs/nerf/run.sh
+# please place video to data/raw/videos/${VIDEO_ID}.mp4
+VIDEO_ID=$1
+echo Processing $VIDEO_ID
+
+echo Resizing the video to 512x512
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -y data/raw/videos/${VIDEO_ID}_512.mp4
+mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
+mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
+echo Done
+echo The old video is moved to data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
+
+echo mkdir -p data/processed/videos/${VIDEO_ID}
+mkdir -p data/processed/videos/${VIDEO_ID}
+echo Done
+
+# extract audio file from the training video
+echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
+echo Done
+
+# extract hubert_mel_f0 from audio
+echo python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
+python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
+echo python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
+python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
+echo Done
+
+# extract segment images
+echo mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
+mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
+echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
+echo Done
+
+echo python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
+python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
+echo Done
+
+echo python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
+python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
+echo Done
+
+pkill -f void*
+echo python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
+python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
+echo Done
+
+echo python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
+python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
+echo Done
\ No newline at end of file
diff --git a/data_gen/utils/mp_feature_extractors/face_landmarker.py b/data_gen/utils/mp_feature_extractors/face_landmarker.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b5904a46809352ef08fd1b3d6948ec4fbc6b7fd
--- /dev/null
+++ b/data_gen/utils/mp_feature_extractors/face_landmarker.py
@@ -0,0 +1,130 @@
+import mediapipe as mp
+from mediapipe.tasks import python
+from mediapipe.tasks.python import vision
+import numpy as np
+import cv2
+import os
+import copy
+
+# simplified mediapipe ldm at https://github.com/k-m-irfan/simplified_mediapipe_face_landmarks
+index_lm141_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [468,469,470,471,472] + [473,474,475,476,477] + [64,4,294]
+# lm141 without iris
+index_lm131_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [64,4,294]
+
+# face alignment lm68
+index_lm68_from_lm478 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
+ 33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
+# used for weights for key parts
+unmatch_mask_from_lm478 = [ 93, 127, 132, 234, 323, 356, 361, 454]
+index_eye_from_lm478 = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
+index_innerlip_from_lm478 = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
+index_outerlip_from_lm478 = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
+index_withinmouth_from_lm478 = [76, 62] + [184, 183, 74, 72, 73, 41, 72, 38, 11, 12, 302, 268, 303, 271, 304, 272, 408, 407] + [292, 306] + [325, 307, 319, 320, 403, 404, 316, 315, 15, 16, 86, 85, 179, 180, 89, 90, 96, 77]
+index_mouth_from_lm478 = index_innerlip_from_lm478 + index_outerlip_from_lm478 + index_withinmouth_from_lm478
+
+index_yaw_from_lm68 = list(range(0, 17))
+index_brow_from_lm68 = list(range(17, 27))
+index_nose_from_lm68 = list(range(27, 36))
+index_eye_from_lm68 = list(range(36, 48))
+index_mouth_from_lm68 = list(range(48, 68))
+
+
+def read_video_to_frames(video_name):
+ frames = []
+ cap = cv2.VideoCapture(video_name)
+ while cap.isOpened():
+ ret, frame_bgr = cap.read()
+ if frame_bgr is None:
+ break
+ frames.append(frame_bgr)
+ frames = np.stack(frames)
+ frames = np.flip(frames, -1) # BGR ==> RGB
+ return frames
+
+class MediapipeLandmarker:
+ def __init__(self):
+ model_path = 'data_gen/utils/mp_feature_extractors/face_landmarker.task'
+ if not os.path.exists(model_path):
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+ print("downloading face_landmarker model from mediapipe...")
+ model_url = 'https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/latest/face_landmarker.task'
+ os.system(f"wget {model_url}")
+ os.system(f"mv face_landmarker.task {model_path}")
+ print("download success")
+ base_options = python.BaseOptions(model_asset_path=model_path)
+ self.image_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
+ running_mode=vision.RunningMode.IMAGE, # IMAGE, VIDEO, LIVE_STREAM
+ num_faces=1)
+ self.video_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
+ running_mode=vision.RunningMode.VIDEO, # IMAGE, VIDEO, LIVE_STREAM
+ num_faces=1)
+
+ def extract_lm478_from_img_name(self, img_name):
+ img = cv2.imread(img_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_lm478 = self.extract_lm478_from_img(img)
+ return img_lm478
+
+ def extract_lm478_from_img(self, img):
+ img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
+ frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=img.astype(np.uint8))
+ img_face_landmarker_result = img_landmarker.detect(image=frame)
+ img_ldm_i = img_face_landmarker_result.face_landmarks[0]
+ img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
+ H, W, _ = img.shape
+ img_lm478 = np.array(img_face_landmarks)[:, :2] * np.array([W, H]).reshape([1,2]) # [478, 2]
+ return img_lm478
+
+ def extract_lm478_from_video_name(self, video_name, fps=25, anti_smooth_factor=2):
+ frames = read_video_to_frames(video_name)
+ img_lm478, vid_lm478 = self.extract_lm478_from_frames(frames, fps, anti_smooth_factor)
+ return img_lm478, vid_lm478
+
+ def extract_lm478_from_frames(self, frames, fps=25, anti_smooth_factor=20):
+ """
+ frames: RGB, uint8
+ anti_smooth_factor: float, 对video模式的interval进行修改, 1代表无修改, 越大越接近image mode
+ """
+ img_mpldms = []
+ vid_mpldms = []
+ img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
+ vid_landmarker = vision.FaceLandmarker.create_from_options(self.video_mode_options)
+
+ for i in range(len(frames)):
+ frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=frames[i].astype(np.uint8))
+ img_face_landmarker_result = img_landmarker.detect(image=frame)
+ vid_face_landmarker_result = vid_landmarker.detect_for_video(image=frame, timestamp_ms=int((1000/fps)*anti_smooth_factor*i))
+ try:
+ img_ldm_i = img_face_landmarker_result.face_landmarks[0]
+ vid_ldm_i = vid_face_landmarker_result.face_landmarks[0]
+ except:
+ print(f"Warning: failed detect ldm in idx={i}, use previous frame results.")
+ img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
+ vid_face_landmarks = np.array([[l.x, l.y, l.z] for l in vid_ldm_i])
+ img_mpldms.append(img_face_landmarks)
+ vid_mpldms.append(vid_face_landmarks)
+ img_lm478 = np.stack(img_mpldms)[..., :2]
+ vid_lm478 = np.stack(vid_mpldms)[..., :2]
+ bs, H, W, _ = frames.shape
+ img_lm478 = np.array(img_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
+ vid_lm478 = np.array(vid_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
+ return img_lm478, vid_lm478
+
+ def combine_vid_img_lm478_to_lm68(self, img_lm478, vid_lm478):
+ img_lm68 = img_lm478[:, index_lm68_from_lm478]
+ vid_lm68 = vid_lm478[:, index_lm68_from_lm478]
+ combined_lm68 = copy.deepcopy(img_lm68)
+ combined_lm68[:, index_yaw_from_lm68] = vid_lm68[:, index_yaw_from_lm68]
+ combined_lm68[:, index_brow_from_lm68] = vid_lm68[:, index_brow_from_lm68]
+ combined_lm68[:, index_nose_from_lm68] = vid_lm68[:, index_nose_from_lm68]
+ return combined_lm68
+
+ def combine_vid_img_lm478_to_lm478(self, img_lm478, vid_lm478):
+ combined_lm478 = copy.deepcopy(vid_lm478)
+ combined_lm478[:, index_mouth_from_lm478] = img_lm478[:, index_mouth_from_lm478]
+ combined_lm478[:, index_eye_from_lm478] = img_lm478[:, index_eye_from_lm478]
+ return combined_lm478
+
+if __name__ == '__main__':
+ landmarker = MediapipeLandmarker()
+ ret = landmarker.extract_lm478_from_video_name("00000.mp4")
diff --git a/data_gen/utils/mp_feature_extractors/face_landmarker.task b/data_gen/utils/mp_feature_extractors/face_landmarker.task
new file mode 100644
index 0000000000000000000000000000000000000000..fedb14de6d2b6708a56c04ae259783e23404c1aa
--- /dev/null
+++ b/data_gen/utils/mp_feature_extractors/face_landmarker.task
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
+size 3758596
diff --git a/data_gen/utils/mp_feature_extractors/mp_segmenter.py b/data_gen/utils/mp_feature_extractors/mp_segmenter.py
new file mode 100644
index 0000000000000000000000000000000000000000..14ed79201e82c13cdcf67dd9d75ea1b945edfbe1
--- /dev/null
+++ b/data_gen/utils/mp_feature_extractors/mp_segmenter.py
@@ -0,0 +1,274 @@
+import os
+import copy
+import numpy as np
+import tqdm
+import mediapipe as mp
+import torch
+from mediapipe.tasks import python
+from mediapipe.tasks.python import vision
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm, multiprocess_run
+from utils.commons.tensor_utils import convert_to_np
+from sklearn.neighbors import NearestNeighbors
+
+def scatter_np(condition_img, classSeg=5):
+# def scatter(condition_img, classSeg=19, label_size=(512, 512)):
+ batch, c, height, width = condition_img.shape
+ # if height != label_size[0] or width != label_size[1]:
+ # condition_img= F.interpolate(condition_img, size=label_size, mode='nearest')
+ input_label = np.zeros([batch, classSeg, condition_img.shape[2], condition_img.shape[3]]).astype(np.int_)
+ # input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
+ np.put_along_axis(input_label, condition_img, 1, 1)
+ return input_label
+
+def scatter(condition_img, classSeg=19):
+# def scatter(condition_img, classSeg=19, label_size=(512, 512)):
+ batch, c, height, width = condition_img.size()
+ # if height != label_size[0] or width != label_size[1]:
+ # condition_img= F.interpolate(condition_img, size=label_size, mode='nearest')
+ input_label = torch.zeros(batch, classSeg, condition_img.shape[2], condition_img.shape[3], device=condition_img.device)
+ # input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
+ return input_label.scatter_(1, condition_img.long(), 1)
+
+def encode_segmap_mask_to_image(segmap):
+ # rgb
+ _,h,w = segmap.shape
+ encoded_img = np.ones([h,w,3],dtype=np.uint8) * 255
+ colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
+ for i, color in enumerate(colors):
+ mask = segmap[i].astype(int)
+ index = np.where(mask != 0)
+ encoded_img[index[0], index[1], :] = np.array(color)
+ return encoded_img.astype(np.uint8)
+
+def decode_segmap_mask_from_image(encoded_img):
+ # rgb
+ colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
+ bg = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
+ hair = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
+ body_skin = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 255)
+ face_skin = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
+ clothes = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 0)
+ others = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
+ segmap = np.stack([bg, hair, body_skin, face_skin, clothes, others], axis=0)
+ return segmap.astype(np.uint8)
+
+def read_video_frame(video_name, frame_id):
+ # https://blog.csdn.net/bby1987/article/details/108923361
+ # frame_num = video_capture.get(cv2.CAP_PROP_FRAME_COUNT) # ==> 总帧数
+ # fps = video_capture.get(cv2.CAP_PROP_FPS) # ==> 帧率
+ # width = video_capture.get(cv2.CAP_PROP_FRAME_WIDTH) # ==> 视频宽度
+ # height = video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT) # ==> 视频高度
+ # pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES) # ==> 句柄位置
+ # video_capture.set(cv2.CAP_PROP_POS_FRAMES, 1000) # ==> 设置句柄位置
+ # pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES) # ==> 此时 pos = 1000.0
+ # video_capture.release()
+ vr = cv2.VideoCapture(video_name)
+ vr.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
+ _, frame = vr.read()
+ return frame
+
+def decode_segmap_mask_from_segmap_video_frame(video_frame):
+ # video_frame: 0~255 BGR, obtained by read_video_frame
+ def assign_values(array):
+ remainder = array % 40 # 计算数组中每个值与40的余数
+ assigned_values = np.where(remainder <= 20, array - remainder, array + (40 - remainder))
+ return assigned_values
+ segmap = video_frame.mean(-1)
+ segmap = assign_values(segmap) // 40 # [H, W] with value 0~5
+ segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
+ return segmap.astype(np.uint8)
+
+def extract_background(img_lst, segmap_lst=None):
+ """
+ img_lst: list of rgb ndarray
+ """
+ # only use 1/20 images
+ num_frames = len(img_lst)
+ img_lst = img_lst[::20] if num_frames > 20 else img_lst[0:1]
+
+ if segmap_lst is not None:
+ segmap_lst = segmap_lst[::20] if num_frames > 20 else segmap_lst[0:1]
+ assert len(img_lst) == len(segmap_lst)
+ # get H/W
+ h, w = img_lst[0].shape[:2]
+
+ # nearest neighbors
+ all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
+ distss = []
+ for idx, img in enumerate(img_lst):
+ if segmap_lst is not None:
+ segmap = segmap_lst[idx]
+ else:
+ segmap = seg_model._cal_seg_map(img)
+ bg = (segmap[0]).astype(bool)
+ fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
+ dists, _ = nbrs.kneighbors(all_xys)
+ distss.append(dists)
+
+ distss = np.stack(distss)
+ max_dist = np.max(distss, 0)
+ max_id = np.argmax(distss, 0)
+
+ bc_pixs = max_dist > 10 # 5
+ bc_pixs_id = np.nonzero(bc_pixs)
+ bc_ids = max_id[bc_pixs]
+
+ num_pixs = distss.shape[1]
+ imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
+
+ bg_img = np.zeros((h*w, 3), dtype=np.uint8)
+ bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
+ bg_img = bg_img.reshape(h, w, 3)
+
+ max_dist = max_dist.reshape(h, w)
+ bc_pixs = max_dist > 10 # 5
+ bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
+ fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
+ distances, indices = nbrs.kneighbors(bg_xys)
+ bg_fg_xys = fg_xys[indices[:, 0]]
+ bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
+ return bg_img
+
+
+class MediapipeSegmenter:
+ def __init__(self):
+ model_path = 'data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite'
+ if not os.path.exists(model_path):
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+ print("downloading segmenter model from mediapipe...")
+ os.system(f"wget https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite")
+ os.system(f"mv selfie_multiclass_256x256.tflite {model_path}")
+ print("download success")
+ base_options = python.BaseOptions(model_asset_path=model_path)
+ self.options = vision.ImageSegmenterOptions(base_options=base_options,running_mode=vision.RunningMode.IMAGE, output_category_mask=True)
+ self.video_options = vision.ImageSegmenterOptions(base_options=base_options,running_mode=vision.RunningMode.VIDEO, output_category_mask=True)
+
+ def _cal_seg_map_for_video(self, imgs, segmenter=None, return_onehot_mask=True, return_segmap_image=True, debug_fill=False):
+ segmenter = vision.ImageSegmenter.create_from_options(self.video_options) if segmenter is None else segmenter
+ assert return_onehot_mask or return_segmap_image # you should at least return one
+ segmap_masks = []
+ segmap_images = []
+ for i in tqdm.trange(len(imgs), desc="extracting segmaps from a video..."):
+ # for i in range(len(imgs)):
+ img = imgs[i]
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
+ out = segmenter.segment_for_video(mp_image, 40 * i)
+ segmap = out.category_mask.numpy_view().copy() # [H, W]
+ if debug_fill:
+ # print(f'segmap {segmap}')
+ for x in range(-80 + 1, 0):
+ for y in range(200, 350):
+ segmap[x][y] = 4
+
+ if return_onehot_mask:
+ segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
+ segmap_masks.append(segmap_mask)
+ if return_segmap_image:
+ segmap_image = segmap[:, :, None].repeat(3, 2).astype(float)
+ segmap_image = (segmap_image * 40).astype(np.uint8)
+ segmap_images.append(segmap_image)
+
+ if return_onehot_mask and return_segmap_image:
+ return segmap_masks, segmap_images
+ elif return_onehot_mask:
+ return segmap_masks
+ elif return_segmap_image:
+ return segmap_images
+
+ def _cal_seg_map(self, img, segmenter=None, return_onehot_mask=True):
+ """
+ segmenter: vision.ImageSegmenter.create_from_options(options)
+ img: numpy, [H, W, 3], 0~255
+ segmap: [C, H, W]
+ 0 - background
+ 1 - hair
+ 2 - body-skin
+ 3 - face-skin
+ 4 - clothes
+ 5 - others (accessories)
+ """
+ assert img.ndim == 3
+ segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
+ out = segmenter.segment(image)
+ segmap = out.category_mask.numpy_view().copy() # [H, W]
+ if return_onehot_mask:
+ segmap = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
+ return segmap
+
+ def _seg_out_img_with_segmap(self, img, segmap, mode='head'):
+ """
+ img: [h,w,c], img is in 0~255, np
+ """
+ #
+ img = copy.deepcopy(img)
+ if mode == 'head':
+ selected_mask = segmap[[1,3,5] , :, :].sum(axis=0)[None,:] > 0.5 # glasses 也属于others
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
+ # selected_mask = segmap[[1,3] , :, :].sum(dim=0, keepdim=True) > 0.5
+ elif mode == 'person':
+ selected_mask = segmap[[1,2,3,4,5], :, :].sum(axis=0)[None,:] > 0.5
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
+ elif mode == 'torso':
+ selected_mask = segmap[[2,4], :, :].sum(axis=0)[None,:] > 0.5
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
+ elif mode == 'torso_with_bg':
+ selected_mask = segmap[[0, 2,4], :, :].sum(axis=0)[None,:] > 0.5
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
+ elif mode == 'bg':
+ selected_mask = segmap[[0], :, :].sum(axis=0)[None,:] > 0.5 # only seg out 0, which means background
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
+ elif mode == 'full':
+ pass
+ else:
+ raise NotImplementedError()
+ return img, selected_mask
+
+ def _seg_out_img(self, img, segmenter=None, mode='head'):
+ """
+ imgs [H, W, 3] 0-255
+ return : person_img [B, 3, H, W]
+ """
+ segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
+ segmap = self._cal_seg_map(img, segmenter=segmenter, return_onehot_mask=True) # [B, 19, H, W]
+ return self._seg_out_img_with_segmap(img, segmap, mode=mode)
+
+ def seg_out_imgs(self, img, mode='head'):
+ """
+ api for pytorch img, -1~1
+ img: [B, 3, H, W], -1~1
+ """
+ device = img.device
+ img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
+ img = ((img + 1) * 127.5).astype(np.uint8)
+ img_lst = [copy.deepcopy(img[i]) for i in range(len(img))]
+ out_lst = []
+ for im in img_lst:
+ out = self._seg_out_img(im, mode=mode)
+ out_lst.append(out)
+ seg_imgs = np.stack(out_lst) # [B, H, W, 3]
+ seg_imgs = (seg_imgs - 127.5) / 127.5
+ seg_imgs = torch.from_numpy(seg_imgs).permute(0, 3, 1, 2).to(device)
+ return seg_imgs
+
+if __name__ == '__main__':
+ import imageio, cv2, tqdm
+ import torchshow as ts
+ img = imageio.imread("1.png")
+ img = cv2.resize(img, (512,512))
+
+ seg_model = MediapipeSegmenter()
+ img = torch.tensor(img).unsqueeze(0).repeat([1, 1, 1, 1]).permute(0, 3,1,2)
+ img = (img-127.5)/127.5
+ out = seg_model.seg_out_imgs(img, 'torso')
+ ts.save(out,"torso.png")
+ out = seg_model.seg_out_imgs(img, 'head')
+ ts.save(out,"head.png")
+ out = seg_model.seg_out_imgs(img, 'bg')
+ ts.save(out,"bg.png")
+ img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
+ img = ((img + 1) * 127.5).astype(np.uint8)
+ bg = extract_background(img)
+ ts.save(bg,"bg2.png")
diff --git a/data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite b/data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite
new file mode 100644
index 0000000000000000000000000000000000000000..9ebdec318f4426502f8d825b8f0332c3e20e29b7
--- /dev/null
+++ b/data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
+size 16371837
diff --git a/data_gen/utils/path_converter.py b/data_gen/utils/path_converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e862fb1810da7c6771d358a39a4043f93c9795
--- /dev/null
+++ b/data_gen/utils/path_converter.py
@@ -0,0 +1,24 @@
+import os
+
+
+class PathConverter():
+ def __init__(self):
+ self.prefixs = {
+ "vid": "/video/",
+ "gt": "/gt_imgs/",
+ "head": "/head_imgs/",
+ "torso": "/torso_imgs/",
+ "person": "/person_imgs/",
+ "torso_with_bg": "/torso_with_bg_imgs/",
+ "single_bg": "/bg_img/",
+ "bg": "/bg_imgs/",
+ "segmaps": "/segmaps/",
+ "inpaint_torso": "/inpaint_torso_imgs/",
+ "com": "/com_imgs/",
+ "inpaint_torso_with_com_bg": "/inpaint_torso_with_com_bg_imgs/",
+ }
+
+ def to(self, path: str, old_pattern: str, new_pattern: str):
+ return path.replace(self.prefixs[old_pattern], self.prefixs[new_pattern], 1)
+
+pc = PathConverter()
\ No newline at end of file
diff --git a/data_gen/utils/process_audio/extract_hubert.py b/data_gen/utils/process_audio/extract_hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..85af486a983b7706f05ea2861565bc7b32d480dd
--- /dev/null
+++ b/data_gen/utils/process_audio/extract_hubert.py
@@ -0,0 +1,95 @@
+from transformers import Wav2Vec2Processor, HubertModel
+import soundfile as sf
+import numpy as np
+import torch
+import os
+from utils.commons.hparams import set_hparams, hparams
+
+
+wav2vec2_processor = None
+hubert_model = None
+
+
+def get_hubert_from_16k_wav(wav_16k_name):
+ speech_16k, _ = sf.read(wav_16k_name)
+ hubert = get_hubert_from_16k_speech(speech_16k)
+ return hubert
+
+@torch.no_grad()
+def get_hubert_from_16k_speech(speech, device="cuda:0"):
+ global hubert_model, wav2vec2_processor
+ local_path = '/home/tiger/.cache/huggingface/hub/models--facebook--hubert-large-ls960-ft/snapshots/ece5fabbf034c1073acae96d5401b25be96709d8'
+ if hubert_model is None:
+ print("Loading the HuBERT Model...")
+ if os.path.exists(local_path):
+ hubert_model = HubertModel.from_pretrained(local_path)
+ else:
+ hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
+ hubert_model = hubert_model.to(device)
+ if wav2vec2_processor is None:
+ print("Loading the Wav2Vec2 Processor...")
+ if os.path.exists(local_path):
+ wav2vec2_processor = Wav2Vec2Processor.from_pretrained(local_path)
+ else:
+ wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
+
+ if speech.ndim ==2:
+ speech = speech[:, 0] # [T, 2] ==> [T,]
+
+ input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T]
+ input_values_all = input_values_all.to(device)
+ # For long audio sequence, due to the memory limitation, we cannot process them in one run
+ # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320
+ # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.
+ # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320
+ # We have the equation to calculate out time step: T = floor((t-k)/s)
+ # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip
+ # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N
+ kernel = 400
+ stride = 320
+ clip_length = stride * 1000
+ num_iter = input_values_all.shape[1] // clip_length
+ expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride
+ res_lst = []
+ for i in range(num_iter):
+ if i == 0:
+ start_idx = 0
+ end_idx = clip_length - stride + kernel
+ else:
+ start_idx = clip_length * i
+ end_idx = start_idx + (clip_length - stride + kernel)
+ input_values = input_values_all[:, start_idx: end_idx]
+ hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
+ res_lst.append(hidden_states[0])
+ if num_iter > 0:
+ input_values = input_values_all[:, clip_length * num_iter:]
+ else:
+ input_values = input_values_all
+
+ if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it
+ hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
+ res_lst.append(hidden_states[0])
+ ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]
+
+ assert abs(ret.shape[0] - expected_T) <= 1
+ if ret.shape[0] < expected_T: # if skipping the last short
+ ret = torch.cat([ret, ret[:, -1:, :].repeat([1,expected_T-ret.shape[0],1])], dim=1)
+ else:
+ ret = ret[:expected_T]
+
+ return ret
+
+
+if __name__ == '__main__':
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument('--video_id', type=str, default='May', help='')
+ args = parser.parse_args()
+ ### Process Single Long Audio for NeRF dataset
+ person_id = args.video_id
+ wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
+ hubert_npy_name = f"data/processed/videos/{person_id}/aud_hubert.npy"
+ speech_16k, _ = sf.read(wav_16k_name)
+ hubert_hidden = get_hubert_from_16k_speech(speech_16k)
+ np.save(hubert_npy_name, hubert_hidden.detach().numpy())
+ print(f"Saved at {hubert_npy_name}")
diff --git a/data_gen/utils/process_audio/extract_mel_f0.py b/data_gen/utils/process_audio/extract_mel_f0.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7d29fe8515f61448431af70c5d3169856b4cef9
--- /dev/null
+++ b/data_gen/utils/process_audio/extract_mel_f0.py
@@ -0,0 +1,148 @@
+import numpy as np
+import torch
+import glob
+import os
+import tqdm
+import librosa
+import parselmouth
+from utils.commons.pitch_utils import f0_to_coarse
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from utils.commons.os_utils import multiprocess_glob
+from utils.audio.io import save_wav
+
+from moviepy.editor import VideoFileClip
+from utils.commons.hparams import hparams, set_hparams
+
+def resample_wav(wav_name, out_name, sr=16000):
+ wav_raw, sr = librosa.core.load(wav_name, sr=sr)
+ save_wav(wav_raw, out_name, sr)
+
+def split_wav(mp4_name, wav_name=None):
+ if wav_name is None:
+ wav_name = mp4_name.replace(".mp4", ".wav").replace("/video/", "/audio/")
+ if os.path.exists(wav_name):
+ return wav_name
+ os.makedirs(os.path.dirname(wav_name), exist_ok=True)
+
+ video = VideoFileClip(mp4_name,verbose=False)
+ dur = video.duration
+ audio = video.audio
+ assert audio is not None
+ audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None)
+ return wav_name
+
+def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
+ '''compute right padding (final frame) or both sides padding (first and final frames)
+ '''
+ assert pad_sides in (1, 2)
+ # return int(fsize // 2)
+ pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+ if pad_sides == 1:
+ return 0, pad
+ else:
+ return pad // 2, pad // 2 + pad % 2
+
+def extract_mel_from_fname(wav_path,
+ fft_size=512,
+ hop_size=320,
+ win_length=512,
+ window="hann",
+ num_mels=80,
+ fmin=80,
+ fmax=7600,
+ eps=1e-6,
+ sample_rate=16000,
+ min_level_db=-100):
+ if isinstance(wav_path, str):
+ wav, _ = librosa.core.load(wav_path, sr=sample_rate)
+ else:
+ wav = wav_path
+
+ # get amplitude spectrogram
+ x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
+ win_length=win_length, window=window, center=False)
+ spc = np.abs(x_stft) # (n_bins, T)
+
+ # get mel basis
+ fmin = 0 if fmin == -1 else fmin
+ fmax = sample_rate / 2 if fmax == -1 else fmax
+ mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel = mel_basis @ spc
+
+ mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T)
+ mel = mel.T
+
+ l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1)
+ wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
+
+ return wav.T, mel
+
+def extract_f0_from_wav_and_mel(wav, mel,
+ hop_size=320,
+ audio_sample_rate=16000,
+ ):
+ time_step = hop_size / audio_sample_rate * 1000
+ f0_min = 80
+ f0_max = 750
+ f0 = parselmouth.Sound(wav, audio_sample_rate).to_pitch_ac(
+ time_step=time_step / 1000, voicing_threshold=0.6,
+ pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
+
+ delta_l = len(mel) - len(f0)
+ assert np.abs(delta_l) <= 8
+ if delta_l > 0:
+ f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
+ f0 = f0[:len(mel)]
+ pitch_coarse = f0_to_coarse(f0)
+ return f0, pitch_coarse
+
+
+def extract_mel_f0_from_fname(wav_name=None, out_name=None):
+ try:
+ out_name = wav_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+
+ wav, mel = extract_mel_from_fname(wav_name)
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
+ out_dict = {
+ "mel": mel, # [T, 80]
+ "f0": f0,
+ }
+ np.save(out_name, out_dict)
+ except Exception as e:
+ print(e)
+
+def extract_mel_f0_from_video_name(mp4_name, wav_name=None, out_name=None):
+ if mp4_name.endswith(".mp4"):
+ wav_name = split_wav(mp4_name, wav_name)
+ if out_name is None:
+ out_name = mp4_name.replace(".mp4", "_mel_f0.npy").replace("/video/", "/mel_f0/")
+ elif mp4_name.endswith(".wav"):
+ wav_name = mp4_name
+ if out_name is None:
+ out_name = mp4_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")
+
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+
+ wav, mel = extract_mel_from_fname(wav_name)
+
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
+ out_dict = {
+ "mel": mel, # [T, 80]
+ "f0": f0,
+ }
+ np.save(out_name, out_dict)
+
+
+if __name__ == '__main__':
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument('--video_id', type=str, default='May', help='')
+ args = parser.parse_args()
+ ### Process Single Long Audio for NeRF dataset
+ person_id = args.video_id
+
+ wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
+ out_name = f"data/processed/videos/{person_id}/aud_mel_f0.npy"
+ extract_mel_f0_from_video_name(wav_16k_name, out_name)
+ print(f"Saved at {out_name}")
\ No newline at end of file
diff --git a/data_gen/utils/process_audio/resample_audio_to_16k.py b/data_gen/utils/process_audio/resample_audio_to_16k.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cc353b9385dc22c30256eb7dedbfb610cd33036
--- /dev/null
+++ b/data_gen/utils/process_audio/resample_audio_to_16k.py
@@ -0,0 +1,49 @@
+import os, glob
+from utils.commons.os_utils import multiprocess_glob
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+
+
+def extract_wav16k_job(audio_name:str):
+ out_path = audio_name.replace("/audio_raw/","/audio/",1)
+ assert out_path != audio_name # prevent inplace
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ ffmpeg_path = "/usr/bin/ffmpeg"
+
+ cmd = f'{ffmpeg_path} -i {audio_name} -ar 16000 -v quiet -y {out_path}'
+ os.system(cmd)
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--aud_dir", default='/home/tiger/datasets/raw/CMLR/audio_raw/')
+ parser.add_argument("--ds_name", default='CMLR')
+ parser.add_argument("--num_workers", default=64, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ args = parser.parse_args()
+ print(f"args {args}")
+
+ aud_dir = args.aud_dir
+ ds_name = args.ds_name
+ if ds_name in ['CMLR']:
+ aud_name_pattern = os.path.join(aud_dir, "*/*/*.wav")
+ aud_names = multiprocess_glob(aud_name_pattern)
+ else:
+ raise NotImplementedError()
+ aud_names = sorted(aud_names)
+ print(f"total audio number : {len(aud_names)}")
+ print(f"first {aud_names[0]} last {aud_names[-1]}")
+ # exit()
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(aud_names) // total_process
+ if process_id == total_process:
+ aud_names = aud_names[process_id * num_samples_per_process : ]
+ else:
+ aud_names = aud_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ for i, res in multiprocess_run_tqdm(extract_wav16k_job, aud_names, num_workers=args.num_workers, desc="resampling videos"):
+ pass
+
diff --git a/data_gen/utils/process_image/extract_lm2d.py b/data_gen/utils/process_image/extract_lm2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ee0ecc02dc94a04b69682a05a7b089d9cd4c8d6
--- /dev/null
+++ b/data_gen/utils/process_image/extract_lm2d.py
@@ -0,0 +1,197 @@
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+import sys
+
+import glob
+import cv2
+import tqdm
+import numpy as np
+from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+import warnings
+warnings.filterwarnings('ignore')
+
+import random
+random.seed(42)
+
+import pickle
+import json
+import gzip
+from typing import Any
+
+def load_file(filename, is_gzip: bool = False, is_json: bool = False) -> Any:
+ if is_json:
+ if is_gzip:
+ with gzip.open(filename, "r", encoding="utf-8") as f:
+ loaded_object = json.load(f)
+ return loaded_object
+ else:
+ with open(filename, "r", encoding="utf-8") as f:
+ loaded_object = json.load(f)
+ return loaded_object
+ else:
+ if is_gzip:
+ with gzip.open(filename, "rb") as f:
+ loaded_object = pickle.load(f)
+ return loaded_object
+ else:
+ with open(filename, "rb") as f:
+ loaded_object = pickle.load(f)
+ return loaded_object
+
+def save_file(filename, content, is_gzip: bool = False, is_json: bool = False) -> None:
+ if is_json:
+ if is_gzip:
+ with gzip.open(filename, "w", encoding="utf-8") as f:
+ json.dump(content, f)
+ else:
+ with open(filename, "w", encoding="utf-8") as f:
+ json.dump(content, f)
+ else:
+ if is_gzip:
+ with gzip.open(filename, "wb") as f:
+ pickle.dump(content, f)
+ else:
+ with open(filename, "wb") as f:
+ pickle.dump(content, f)
+
+face_landmarker = None
+
+def extract_lms_mediapipe_job(img):
+ if img is None:
+ return None
+ global face_landmarker
+ if face_landmarker is None:
+ face_landmarker = MediapipeLandmarker()
+ lm478 = face_landmarker.extract_lm478_from_img(img)
+ return lm478
+
+def extract_landmark_job(img_name):
+ try:
+ # if img_name == 'datasets/PanoHeadGen/raw/images/multi_view/chunk_0/seed0000002.png':
+ # print(1)
+ # input()
+ out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png","_lms.npy")
+ if os.path.exists(out_name):
+ print("out exists, skip...")
+ return
+ try:
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+ except:
+ pass
+ img = cv2.imread(img_name)[:,:,::-1]
+
+ if img is not None:
+ lm468 = extract_lms_mediapipe_job(img)
+ if lm468 is not None:
+ np.save(out_name, lm468)
+ # print("Hahaha, solve one item!!!")
+ except Exception as e:
+ print(e)
+ pass
+
+def out_exist_job(img_name):
+ out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png","_lms.npy")
+ if os.path.exists(out_name):
+ return None
+ else:
+ return img_name
+
+# def get_todo_img_names(img_names):
+# todo_img_names = []
+# for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
+# if res is not None:
+# todo_img_names.append(res)
+# return todo_img_names
+
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512/')
+ parser.add_argument("--ds_name", default='FFHQ')
+ parser.add_argument("--num_workers", default=64, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--reset", action='store_true')
+ parser.add_argument("--img_names_file", default="img_names.pkl", type=str)
+ parser.add_argument("--load_img_names", action="store_true")
+
+ args = parser.parse_args()
+ print(f"args {args}")
+ img_dir = args.img_dir
+ img_names_file = os.path.join(img_dir, args.img_names_file)
+ if args.load_img_names:
+ img_names = load_file(img_names_file)
+ print(f"load image names from {img_names_file}")
+ else:
+ if args.ds_name == 'FFHQ_MV':
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
+ img_names1 = glob.glob(img_name_pattern1)
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
+ img_names2 = glob.glob(img_name_pattern2)
+ img_names = img_names1 + img_names2
+ img_names = sorted(img_names)
+ elif args.ds_name == 'FFHQ':
+ img_name_pattern = os.path.join(img_dir, "*.png")
+ img_names = glob.glob(img_name_pattern)
+ img_names = sorted(img_names)
+ elif args.ds_name == "PanoHeadGen":
+ # img_name_patterns = ["ref/*/*.png", "multi_view/*/*.png", "reverse/*/*.png"]
+ img_name_patterns = ["ref/*/*.png"]
+ img_names = []
+ for img_name_pattern in img_name_patterns:
+ img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
+ img_names_part = glob.glob(img_name_pattern_full)
+ img_names.extend(img_names_part)
+ img_names = sorted(img_names)
+
+ # save image names
+ if not args.load_img_names:
+ save_file(img_names_file, img_names)
+ print(f"save image names in {img_names_file}")
+
+ print(f"total images number: {len(img_names)}")
+
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(img_names) // total_process
+ if process_id == total_process:
+ img_names = img_names[process_id * num_samples_per_process : ]
+ else:
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ # if not args.reset:
+ # img_names = get_todo_img_names(img_names)
+
+
+ print(f"todo_image {img_names[:10]}")
+ print(f"processing images number in this process: {len(img_names)}")
+ # print(f"todo images number: {len(img_names)}")
+ # input()
+ # exit()
+
+ if args.num_workers == 1:
+ index = 0
+ for img_name in tqdm.tqdm(img_names, desc=f"Root process {args.process_id}: extracting MP-based landmark2d"):
+ try:
+ extract_landmark_job(img_name)
+ except Exception as e:
+ print(e)
+ pass
+ if index % max(1, int(len(img_names) * 0.003)) == 0:
+ print(f"processed {index} / {len(img_names)}")
+ sys.stdout.flush()
+ index += 1
+ else:
+ for i, res in multiprocess_run_tqdm(
+ extract_landmark_job, img_names,
+ num_workers=args.num_workers,
+ desc=f"Root {args.process_id}: extracing MP-based landmark2d"):
+ # if index % max(1, int(len(img_names) * 0.003)) == 0:
+ print(f"processed {i+1} / {len(img_names)}")
+ sys.stdout.flush()
+ print(f"Root {args.process_id}: Finished extracting.")
\ No newline at end of file
diff --git a/data_gen/utils/process_image/extract_segment_imgs.py b/data_gen/utils/process_image/extract_segment_imgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..408a6d1b6229e9bd7e2aa1c7c7cdeb067cc0ae7f
--- /dev/null
+++ b/data_gen/utils/process_image/extract_segment_imgs.py
@@ -0,0 +1,114 @@
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+
+import glob
+import cv2
+import tqdm
+import numpy as np
+import PIL
+from utils.commons.tensor_utils import convert_to_np
+import torch
+import mediapipe as mp
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
+from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background, save_rgb_image_to_path
+seg_model = MediapipeSegmenter()
+
+
+def extract_segment_job(img_name):
+ try:
+ img = cv2.imread(img_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ segmap = seg_model._cal_seg_map(img)
+ bg_img = extract_background([img], [segmap])
+ out_img_name = img_name.replace("/images_512/",f"/bg_img/").replace(".mp4", ".jpg")
+ save_rgb_image_to_path(bg_img, out_img_name)
+
+ com_img = img.copy()
+ bg_part = segmap[0].astype(bool)[..., None].repeat(3,axis=-1)
+ com_img[bg_part] = bg_img[bg_part]
+ out_img_name = img_name.replace("/images_512/",f"/com_imgs/")
+ save_rgb_image_to_path(com_img, out_img_name)
+
+ for mode in ['head', 'torso', 'person', 'torso_with_bg', 'bg']:
+ out_img, _ = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
+ out_img_name = img_name.replace("/images_512/",f"/{mode}_imgs/")
+ out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)
+ try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
+ except: pass
+ cv2.imwrite(out_img_name, out_img)
+
+ inpaint_torso_img, inpaint_torso_with_bg_img, _, _ = inpaint_torso_job(img, segmap)
+ out_img_name = img_name.replace("/images_512/",f"/inpaint_torso_imgs/")
+ save_rgb_image_to_path(inpaint_torso_img, out_img_name)
+ inpaint_torso_with_bg_img[bg_part] = bg_img[bg_part]
+ out_img_name = img_name.replace("/images_512/",f"/inpaint_torso_with_com_bg_imgs/")
+ save_rgb_image_to_path(inpaint_torso_with_bg_img, out_img_name)
+ return 0
+ except Exception as e:
+ print(e)
+ return 1
+
+def out_exist_job(img_name):
+ out_name1 = img_name.replace("/images_512/", "/head_imgs/")
+ out_name2 = img_name.replace("/images_512/", "/com_imgs/")
+ out_name3 = img_name.replace("/images_512/", "/inpaint_torso_with_com_bg_imgs/")
+
+ if os.path.exists(out_name1) and os.path.exists(out_name2) and os.path.exists(out_name3):
+ return None
+ else:
+ return img_name
+
+def get_todo_img_names(img_names):
+ todo_img_names = []
+ for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
+ if res is not None:
+ todo_img_names.append(res)
+ return todo_img_names
+
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_dir", default='./images_512')
+ # parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
+ parser.add_argument("--ds_name", default='FFHQ')
+ parser.add_argument("--num_workers", default=1, type=int)
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--reset", action='store_true')
+
+ args = parser.parse_args()
+ img_dir = args.img_dir
+ if args.ds_name == 'FFHQ_MV':
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
+ img_names1 = glob.glob(img_name_pattern1)
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
+ img_names2 = glob.glob(img_name_pattern2)
+ img_names = img_names1 + img_names2
+ elif args.ds_name == 'FFHQ':
+ img_name_pattern = os.path.join(img_dir, "*.png")
+ img_names = glob.glob(img_name_pattern)
+
+ img_names = sorted(img_names)
+ random.seed(args.seed)
+ random.shuffle(img_names)
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(img_names) // total_process
+ if process_id == total_process:
+ img_names = img_names[process_id * num_samples_per_process : ]
+ else:
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ if not args.reset:
+ img_names = get_todo_img_names(img_names)
+ print(f"todo images number: {len(img_names)}")
+
+ for vid_name in multiprocess_run_tqdm(extract_segment_job ,img_names, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
+ pass
\ No newline at end of file
diff --git a/data_gen/utils/process_image/fit_3dmm_landmark.py b/data_gen/utils/process_image/fit_3dmm_landmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fde7d94d919ab2b582fe7ac2e1a11fbe8129fad
--- /dev/null
+++ b/data_gen/utils/process_image/fit_3dmm_landmark.py
@@ -0,0 +1,369 @@
+from numpy.core.numeric import require
+from numpy.lib.function_base import quantile
+import torch
+import torch.nn.functional as F
+import copy
+import numpy as np
+
+import os
+import sys
+import cv2
+import argparse
+import tqdm
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
+
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+import pickle
+
+face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
+ camera_distance=10, focal=1015, keypoint_mode='mediapipe')
+face_model.to("cuda")
+
+
+index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
+ 33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+
+LAMBDA_REG_ID = 0.3
+LAMBDA_REG_EXP = 0.05
+
+def save_file(name, content):
+ with open(name, "wb") as f:
+ pickle.dump(content, f)
+
+def load_file(name):
+ with open(name, "rb") as f:
+ content = pickle.load(f)
+ return content
+
+def cal_lan_loss_mp(proj_lan, gt_lan):
+ # [B, 68, 2]
+ loss = (proj_lan - gt_lan).pow(2)
+ # loss = (proj_lan - gt_lan).abs()
+ unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
+ eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
+ inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
+ outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
+ weights = torch.ones_like(loss)
+ weights[:, eye] = 5
+ weights[:, inner_lip] = 2
+ weights[:, outer_lip] = 2
+ weights[:, unmatch_mask] = 0
+ loss = loss * weights
+ return torch.mean(loss)
+
+def cal_lan_loss(proj_lan, gt_lan):
+ # [B, 68, 2]
+ loss = (proj_lan - gt_lan)** 2
+ # use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
+ weights = torch.zeros_like(loss)
+ weights = torch.ones_like(loss)
+ weights[:, 36:48, :] = 3 # eye 12 points
+ weights[:, -8:, :] = 3 # inner lip 8 points
+ weights[:, 28:31, :] = 3 # nose 3 points
+ loss = loss * weights
+ return torch.mean(loss)
+
+def set_requires_grad(tensor_list):
+ for tensor in tensor_list:
+ tensor.requires_grad = True
+
+def read_video_to_frames(img_name):
+ frames = []
+ cap = cv2.VideoCapture(img_name)
+ while cap.isOpened():
+ ret, frame_bgr = cap.read()
+ if frame_bgr is None:
+ break
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+ frames.append(frame_rgb)
+ return np.stack(frames)
+
+@torch.enable_grad()
+def fit_3dmm_for_a_image(img_name, debug=False, keypoint_mode='mediapipe', device="cuda:0", save=True):
+ img = cv2.imread(img_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_h, img_w = img.shape[0], img.shape[0]
+ assert img_h == img_w
+ num_frames = 1
+
+ lm_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
+ if lm_name.endswith('_lms.npy') and os.path.exists(lm_name):
+ lms = np.load(lm_name)
+ else:
+ # print("lms_2d file not found, try to extract it from image...")
+ try:
+ landmarker = MediapipeLandmarker()
+ lms = landmarker.extract_lm478_from_img_name(img_name)
+ # lms = landmarker.extract_lm478_from_img(img)
+ except Exception as e:
+ print(e)
+ return
+ if lms is None:
+ print("get None lms_2d, please check whether each frame has one head, exiting...")
+ return
+ lms = lms[:468].reshape([468,2])
+ lms = torch.FloatTensor(lms).to(device=device)
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
+
+ if keypoint_mode == 'mediapipe':
+ cal_lan_loss_fn = cal_lan_loss_mp
+ out_name = img_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png", "_coeff_fit_mp.npy")
+ else:
+ cal_lan_loss_fn = cal_lan_loss
+ out_name = img_name.replace("/images_512/", "/coeff_fit_lm68/").replace(".png", "_coeff_fit_lm68.npy")
+ try:
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+ except:
+ pass
+
+ id_dim, exp_dim = 80, 64
+ sel_ids = np.arange(0, num_frames, 40)
+ sel_num = sel_ids.shape[0]
+ arg_focal = face_model.focal
+
+ h = w = face_model.center * 2
+ img_scale_factor = img_h / h
+ lms /= img_scale_factor
+ cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).to(device=device)
+
+ id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True) # lms.new_zeros((1, id_dim), requires_grad=True)
+ exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
+ euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
+ trans = lms.new_zeros((num_frames, 3), requires_grad=True)
+
+ focal_length = lms.new_zeros(1, requires_grad=True)
+ focal_length.data += arg_focal
+
+ set_requires_grad([id_para, exp_para, euler_angle, trans])
+
+ optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
+ optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
+
+ # 其他参数初始化,先训练euler和trans
+ for _ in range(200):
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para, exp_para, euler_angle, trans)
+ loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
+ loss = loss_lan
+ optimizer_frame.zero_grad()
+ loss.backward()
+ optimizer_frame.step()
+ # print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
+
+ for param_group in optimizer_frame.param_groups:
+ param_group['lr'] = 0.1
+
+ # "jointly roughly training id exp euler trans"
+ for _ in range(200):
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para, exp_para, euler_angle, trans)
+ loss_lan = cal_lan_loss_fn(
+ proj_geo[:, :, :2], lms.detach())
+ loss_regid = torch.mean(id_para*id_para) # 正则化
+ loss_regexp = torch.mean(exp_para * exp_para)
+
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
+ optimizer_idexp.zero_grad()
+ optimizer_frame.zero_grad()
+ loss.backward()
+ optimizer_idexp.step()
+ optimizer_frame.step()
+ # print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
+ # print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
+
+ # start fine training, intialize from the roughly trained results
+ id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
+ id_para_.data = id_para.data.clone()
+ id_para = id_para_
+ exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
+ exp_para_.data = exp_para.data.clone()
+ exp_para = exp_para_
+ euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
+ euler_angle_.data = euler_angle.data.clone()
+ euler_angle = euler_angle_
+ trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
+ trans_.data = trans.data.clone()
+ trans = trans_
+
+ batch_size = 1
+
+ # "fine fitting the 3DMM in batches"
+ for i in range(int((num_frames-1)/batch_size+1)):
+ if (i+1)*batch_size > num_frames:
+ start_n = num_frames-batch_size
+ sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
+ else:
+ start_n = i*batch_size
+ sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
+ sel_lms = lms[sel_ids]
+
+ sel_id_para = id_para.new_zeros(
+ (batch_size, id_dim), requires_grad=True)
+ sel_id_para.data = id_para[sel_ids].clone()
+ sel_exp_para = exp_para.new_zeros(
+ (batch_size, exp_dim), requires_grad=True)
+ sel_exp_para.data = exp_para[sel_ids].clone()
+ sel_euler_angle = euler_angle.new_zeros(
+ (batch_size, 3), requires_grad=True)
+ sel_euler_angle.data = euler_angle[sel_ids].clone()
+ sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
+ sel_trans.data = trans[sel_ids].clone()
+
+ set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
+ optimizer_cur_batch = torch.optim.Adam(
+ [sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
+
+ for j in range(50):
+ proj_geo = face_model.compute_for_landmark_fit(
+ sel_id_para, sel_exp_para, sel_euler_angle, sel_trans)
+ loss_lan = cal_lan_loss_fn(
+ proj_geo[:, :, :2], lms.unsqueeze(0).detach())
+
+ loss_regid = torch.mean(sel_id_para*sel_id_para) # 正则化
+ loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
+ optimizer_cur_batch.zero_grad()
+ loss.backward()
+ optimizer_cur_batch.step()
+ print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f}")
+ id_para[sel_ids].data = sel_id_para.data.clone()
+ exp_para[sel_ids].data = sel_exp_para.data.clone()
+ euler_angle[sel_ids].data = sel_euler_angle.data.clone()
+ trans[sel_ids].data = sel_trans.data.clone()
+
+ coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
+ 'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
+ if save:
+ np.save(out_name, coeff_dict, allow_pickle=True)
+
+ if debug:
+ import imageio
+ debug_name = img_name.replace("/images_512/", "/coeff_fit_mp_debug/").replace(".png", "_debug.png").replace(".jpg", "_debug.jpg")
+ try: os.makedirs(os.path.dirname(debug_name), exist_ok=True)
+ except: pass
+ proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
+ lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
+ lm68s = lm68s * img_scale_factor
+ lms = lms * img_scale_factor
+ lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
+ lm68s = lm68s.astype(int)
+ lm68s = lm68s.reshape([-1,2])
+ lms = lms.cpu().numpy().astype(int).reshape([-1,2])
+ for lm in lm68s:
+ img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1)
+ for gt_lm in lms:
+ img = cv2.circle(img, gt_lm, 2, (255, 0, 0), thickness=1)
+ imageio.imwrite(debug_name, img)
+ print(f"debug img saved at {debug_name}")
+ return coeff_dict
+
+def out_exist_job(vid_name):
+ out_name = vid_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png","_coeff_fit_mp.npy")
+ # if os.path.exists(out_name) or not os.path.exists(lms_name):
+ if os.path.exists(out_name):
+ return None
+ else:
+ return vid_name
+
+def get_todo_img_names(img_names):
+ todo_img_names = []
+ for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=16):
+ if res is not None:
+ todo_img_names.append(res)
+ return todo_img_names
+
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
+ parser.add_argument("--ds_name", default='FFHQ')
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
+ parser.add_argument("--debug", action='store_true')
+ parser.add_argument("--reset", action='store_true')
+ parser.add_argument("--device", default="cuda:0", type=str)
+ parser.add_argument("--output_log", action='store_true')
+ parser.add_argument("--load_names", action="store_true")
+
+ args = parser.parse_args()
+ img_dir = args.img_dir
+ load_names = args.load_names
+
+ print(f"args {args}")
+
+ if args.ds_name == 'single_img':
+ img_names = [img_dir]
+ else:
+ img_names_path = os.path.join(img_dir, "img_dir.pkl")
+ if os.path.exists(img_names_path) and load_names:
+ print(f"loading vid names from {img_names_path}")
+ img_names = load_file(img_names_path)
+ else:
+ if args.ds_name == 'FFHQ_MV':
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
+ img_names1 = glob.glob(img_name_pattern1)
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
+ img_names2 = glob.glob(img_name_pattern2)
+ img_names = img_names1 + img_names2
+ img_names = sorted(img_names)
+ elif args.ds_name == 'FFHQ':
+ img_name_pattern = os.path.join(img_dir, "*.png")
+ img_names = glob.glob(img_name_pattern)
+ img_names = sorted(img_names)
+ elif args.ds_name == "PanoHeadGen":
+ img_name_patterns = ["ref/*/*.png"]
+ img_names = []
+ for img_name_pattern in img_name_patterns:
+ img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
+ img_names_part = glob.glob(img_name_pattern_full)
+ img_names.extend(img_names_part)
+ img_names = sorted(img_names)
+ print(f"saving image names to {img_names_path}")
+ save_file(img_names_path, img_names)
+
+ # import random
+ # random.seed(args.seed)
+ # random.shuffle(img_names)
+
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
+ camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
+ face_model.to(torch.device(args.device))
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1 and process_id >= 0
+ num_samples_per_process = len(img_names) // total_process
+ if process_id == total_process:
+ img_names = img_names[process_id * num_samples_per_process : ]
+ else:
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+ print(f"image names number (before fileter): {len(img_names)}")
+
+
+ if not args.reset:
+ img_names = get_todo_img_names(img_names)
+
+ print(f"image names number (after fileter): {len(img_names)}")
+ for i in tqdm.trange(len(img_names), desc=f"process {process_id}: fitting 3dmm ..."):
+ img_name = img_names[i]
+ try:
+ fit_3dmm_for_a_image(img_name, args.debug, device=args.device)
+ except Exception as e:
+ print(img_name, e)
+ if args.output_log and i % max(int(len(img_names) * 0.003), 1) == 0:
+ print(f"process {process_id}: {i + 1} / {len(img_names)} done")
+ sys.stdout.flush()
+ sys.stderr.flush()
+
+ print(f"process {process_id}: fitting 3dmm all done")
+
diff --git a/data_gen/utils/process_video/euler2quaterion.py b/data_gen/utils/process_video/euler2quaterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3fd35af0e26285dafac2931fad5904e9d30321a
--- /dev/null
+++ b/data_gen/utils/process_video/euler2quaterion.py
@@ -0,0 +1,35 @@
+import numpy as np
+import torch
+import math
+import numba
+from scipy.spatial.transform import Rotation as R
+
+def euler2quaterion(euler, use_radian=True):
+ """
+ euler: np.array, [batch, 3]
+ return: the quaterion, np.array, [batch, 4]
+ """
+ r = R.from_euler('xyz',euler, degrees=not use_radian)
+ return r.as_quat()
+
+def quaterion2euler(quat, use_radian=True):
+ """
+ quat: np.array, [batch, 4]
+ return: the euler, np.array, [batch, 3]
+ """
+ r = R.from_quat(quat)
+ return r.as_euler('xyz', degrees=not use_radian)
+
+def rot2quaterion(rot):
+ r = R.from_matrix(rot)
+ return r.as_quat()
+
+def quaterion2rot(quat):
+ r = R.from_quat(quat)
+ return r.as_matrix()
+
+if __name__ == '__main__':
+ euler = np.array([89.999,89.999,89.999] * 100).reshape([100,3])
+ q = euler2quaterion(euler, use_radian=False)
+ e = quaterion2euler(q, use_radian=False)
+ print(" ")
diff --git a/data_gen/utils/process_video/extract_blink.py b/data_gen/utils/process_video/extract_blink.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6d27bb077d401a9c8e9b5b19b121c538db9e037
--- /dev/null
+++ b/data_gen/utils/process_video/extract_blink.py
@@ -0,0 +1,50 @@
+import numpy as np
+from data_util.face3d_helper import Face3DHelper
+from utils.commons.tensor_utils import convert_to_tensor
+
+def polygon_area(x, y):
+ """
+ x: [T, K=6]
+ y: [T, K=6]
+ return: [T,]
+ """
+ x_ = x - x.mean(axis=-1, keepdims=True)
+ y_ = y - y.mean(axis=-1, keepdims=True)
+ correction = x_[:,-1] * y_[:,0] - y_[:,-1]* x_[:,0]
+ main_area = (x_[:,:-1] * y_[:,1:]).sum(axis=-1) - (y_[:,:-1] * x_[:,1:]).sum(axis=-1)
+ return 0.5 * np.abs(main_area + correction)
+
+def get_eye_area_percent(id, exp, face3d_helper):
+ id = convert_to_tensor(id)
+ exp = convert_to_tensor(exp)
+ cano_lm3d = face3d_helper.reconstruct_cano_lm3d(id, exp)
+ cano_lm2d = (cano_lm3d[..., :2] + 1) / 2
+ lms = cano_lm2d.cpu().numpy()
+ eyes_left = slice(36, 42)
+ eyes_right = slice(42, 48)
+ area_left = polygon_area(lms[:, eyes_left, 0], lms[:, eyes_left, 1])
+ area_right = polygon_area(lms[:, eyes_right, 0], lms[:, eyes_right, 1])
+ # area percentage of two eyes of the whole image...
+ area_percent = (area_left + area_right) / 1 * 100 # recommend threshold is 0.25%
+ return area_percent # [T,]
+
+
+if __name__ == '__main__':
+ import numpy as np
+ import imageio
+ import cv2
+ import torch
+ from data_gen.utils.process_video.extract_lm2d import extract_lms_mediapipe_job, read_video_to_frames, index_lm68_from_lm468
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
+ from data_util.face3d_helper import Face3DHelper
+
+ face3d_helper = Face3DHelper()
+ video_name = 'data/raw/videos/May_10s.mp4'
+ frames = read_video_to_frames(video_name)
+ coeff = fit_3dmm_for_a_video(video_name, save=False)
+ area_percent = get_eye_area_percent(torch.tensor(coeff['id']), torch.tensor(coeff['exp']), face3d_helper)
+ writer = imageio.get_writer("1.mp4", fps=25)
+ for idx, frame in enumerate(frames):
+ frame = cv2.putText(frame, f"{area_percent[idx]:.2f}", org=(128,128), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=1, color=(255,0,0), thickness=1)
+ writer.append_data(frame)
+ writer.close()
\ No newline at end of file
diff --git a/data_gen/utils/process_video/extract_lm2d.py b/data_gen/utils/process_video/extract_lm2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9ae0b13408c1d837af4c40912ffc58e0043469b
--- /dev/null
+++ b/data_gen/utils/process_video/extract_lm2d.py
@@ -0,0 +1,164 @@
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+import sys
+import glob
+import cv2
+import pickle
+import tqdm
+import numpy as np
+import mediapipe as mp
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from utils.commons.os_utils import multiprocess_glob
+from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
+import warnings
+import traceback
+
+warnings.filterwarnings('ignore')
+
+"""
+基于Face_aligment的lm68已被弃用,因为其:
+1. 对眼睛部位的预测精度极低
+2. 无法在大偏转角度时准确预测被遮挡的下颚线, 导致大角度时3dmm的GT label就是有问题的, 从而影响性能
+我们目前转而使用基于mediapipe的lm68
+"""
+# def extract_landmarks(ori_imgs_dir):
+
+# print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')
+
+# fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
+# image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.png'))
+# for image_path in tqdm.tqdm(image_paths):
+# out_name = image_path.replace("/images_512/", "/lms_2d/").replace(".png",".lms")
+# if os.path.exists(out_name):
+# continue
+# input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
+# input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
+# preds = fa.get_landmarks(input)
+# if preds is None:
+# print(f"Skip {image_path} for no face detected")
+# continue
+# if len(preds) > 0:
+# lands = preds[0].reshape(-1, 2)[:,:2]
+# os.makedirs(os.path.dirname(out_name), exist_ok=True)
+# np.savetxt(out_name, lands, '%f')
+# del fa
+# print(f'[INFO] ===== extracted face landmarks =====')
+
+def save_file(name, content):
+ with open(name, "wb") as f:
+ pickle.dump(content, f)
+
+def load_file(name):
+ with open(name, "rb") as f:
+ content = pickle.load(f)
+ return content
+
+
+face_landmarker = None
+
+def extract_landmark_job(video_name, nerf=False):
+ try:
+ if nerf:
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
+ else:
+ out_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
+ if os.path.exists(out_name):
+ # print("out exists, skip...")
+ return
+ try:
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+ except:
+ pass
+ global face_landmarker
+ if face_landmarker is None:
+ face_landmarker = MediapipeLandmarker()
+ img_lm478, vid_lm478 = face_landmarker.extract_lm478_from_video_name(video_name)
+ lm478 = face_landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
+ np.save(out_name, lm478)
+ return True
+ # print("Hahaha, solve one item!!!")
+ except Exception as e:
+ traceback.print_exc()
+ return False
+
+def out_exist_job(vid_name):
+ out_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
+ if os.path.exists(out_name):
+ return None
+ else:
+ return vid_name
+
+def get_todo_vid_names(vid_names):
+ if len(vid_names) == 1: # nerf
+ return vid_names
+ todo_vid_names = []
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=128):
+ if res is not None:
+ todo_vid_names.append(res)
+ return todo_vid_names
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vid_dir", default='nerf')
+ parser.add_argument("--ds_name", default='data/raw/videos/May.mp4')
+ parser.add_argument("--num_workers", default=2, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--reset", action="store_true")
+ parser.add_argument("--load_names", action="store_true")
+
+ args = parser.parse_args()
+ vid_dir = args.vid_dir
+ ds_name = args.ds_name
+ load_names = args.load_names
+
+ if ds_name.lower() == 'nerf': # 处理单个视频
+ vid_names = [vid_dir]
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy") for video_name in vid_names]
+ else: # 处理整个数据集
+ if ds_name in ['lrs3_trainval']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ elif ds_name in ["RAVDESS", 'VFHQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
+ else:
+ raise NotImplementedError()
+
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
+ if os.path.exists(vid_names_path) and load_names:
+ print(f"loading vid names from {vid_names_path}")
+ vid_names = load_file(vid_names_path)
+ else:
+ vid_names = multiprocess_glob(vid_name_pattern)
+ vid_names = sorted(vid_names)
+ if not load_names:
+ print(f"saving vid names to {vid_names_path}")
+ save_file(vid_names_path, vid_names)
+ out_names = [video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy") for video_name in vid_names]
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(vid_names) // total_process
+ if process_id == total_process:
+ vid_names = vid_names[process_id * num_samples_per_process : ]
+ else:
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ if not args.reset:
+ vid_names = get_todo_vid_names(vid_names)
+ print(f"todo videos number: {len(vid_names)}")
+
+ fail_cnt = 0
+ job_args = [(vid_name, ds_name=='nerf') for vid_name in vid_names]
+ for (i, res) in multiprocess_run_tqdm(extract_landmark_job, job_args, num_workers=args.num_workers, desc=f"Root {args.process_id}: extracing MP-based landmark2d"):
+ if res is False:
+ fail_cnt += 1
+ print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {fail_cnt} / {i + 1} = {fail_cnt / (i + 1):.4f}")
+ sys.stdout.flush()
+ pass
\ No newline at end of file
diff --git a/data_gen/utils/process_video/extract_segment_imgs.py b/data_gen/utils/process_video/extract_segment_imgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..868773c4b323897c6654834ed09b87820b1dc7b4
--- /dev/null
+++ b/data_gen/utils/process_video/extract_segment_imgs.py
@@ -0,0 +1,500 @@
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+import random
+import glob
+import cv2
+import tqdm
+import numpy as np
+import PIL
+from utils.commons.tensor_utils import convert_to_np
+from utils.commons.os_utils import multiprocess_glob
+import pickle
+import torch
+import mediapipe as mp
+import traceback
+import multiprocessing
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from scipy.ndimage import binary_erosion, binary_dilation
+from sklearn.neighbors import NearestNeighbors
+from mediapipe.tasks.python import vision
+from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter, encode_segmap_mask_to_image, decode_segmap_mask_from_image
+
+seg_model = None
+segmenter = None
+mat_model = None
+lama_model = None
+lama_config = None
+
+from data_gen.utils.process_video.split_video_to_imgs import extract_img_job
+
+BG_NAME_MAP = {
+ "knn": "",
+ "mat": "_mat",
+ "ddnm": "_ddnm",
+ "lama": "_lama",
+}
+FRAME_SELECT_INTERVAL = 5
+SIM_METHOD = "mse"
+SIM_THRESHOLD = 3
+
+def save_file(name, content):
+ with open(name, "wb") as f:
+ pickle.dump(content, f)
+
+def load_file(name):
+ with open(name, "rb") as f:
+ content = pickle.load(f)
+ return content
+
+def save_rgb_alpha_image_to_path(img, alpha, img_path):
+ try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
+ except: pass
+ cv2.imwrite(img_path, np.concatenate([cv2.cvtColor(img, cv2.COLOR_RGB2BGR), alpha], axis=-1))
+
+def save_rgb_image_to_path(img, img_path):
+ try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
+ except: pass
+ cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+
+def load_rgb_image_to_path(img_path):
+ return cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
+
+def image_similarity(x: np.ndarray, y: np.ndarray, method="mse"):
+ if method == "mse":
+ return np.mean((x - y) ** 2)
+ else:
+ raise NotImplementedError
+
+def extract_background(img_lst, segmap_mask_lst=None, method="knn", device='cpu', mix_bg=True):
+ """
+ img_lst: list of rgb ndarray
+ method: "knn", "mat" or "ddnm"
+ """
+ # only use 1/20 images
+ global segmenter
+ global seg_model
+ global mat_model
+ global lama_model
+ global lama_config
+
+ assert len(img_lst) > 0
+ if segmap_mask_lst is not None:
+ assert len(segmap_mask_lst) == len(img_lst)
+ else:
+ del segmenter
+ del seg_model
+ seg_model = MediapipeSegmenter()
+ segmenter = vision.ImageSegmenter.create_from_options(seg_model.video_options)
+
+ def get_segmap_mask(img_lst, segmap_mask_lst, index):
+ if segmap_mask_lst is not None:
+ segmap = segmap_mask_lst[index]
+ else:
+ segmap = seg_model._cal_seg_map(img_lst[index], segmenter=segmenter)
+ return segmap
+
+ if method == "knn":
+ num_frames = len(img_lst)
+ img_lst = img_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else img_lst[0:1]
+
+ if segmap_mask_lst is not None:
+ segmap_mask_lst = segmap_mask_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else segmap_mask_lst[0:1]
+ assert len(img_lst) == len(segmap_mask_lst)
+ # get H/W
+ h, w = img_lst[0].shape[:2]
+
+ # nearest neighbors
+ all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() # [512*512, 2] coordinate grid
+ distss = []
+ for idx, img in enumerate(img_lst):
+ segmap = get_segmap_mask(img_lst=img_lst, segmap_mask_lst=segmap_mask_lst, index=idx)
+ bg = (segmap[0]).astype(bool) # [h,w] bool mask
+ fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) # [N_nonbg,2] coordinate of non-bg pixels
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
+ dists, _ = nbrs.kneighbors(all_xys) # [512*512, 1] distance to nearest non-bg pixel
+ distss.append(dists)
+
+ distss = np.stack(distss) # [B, 512*512, 1]
+ max_dist = np.max(distss, 0) # [512*512, 1]
+ max_id = np.argmax(distss, 0) # id of frame
+
+ bc_pixs = max_dist > 10 # 在各个frame有一个出现过是bg的pixel,bg标准是离最近的non-bg pixel距离大于10
+ bc_pixs_id = np.nonzero(bc_pixs)
+ bc_ids = max_id[bc_pixs]
+
+ num_pixs = distss.shape[1]
+ imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
+
+ bg_img = np.zeros((h*w, 3), dtype=np.uint8)
+ bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] # 对那些铁bg的pixel,直接去对应的image里面采样
+ bg_img = bg_img.reshape(h, w, 3)
+
+ max_dist = max_dist.reshape(h, w)
+ bc_pixs = max_dist > 10 # 5
+ bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
+ fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
+ distances, indices = nbrs.kneighbors(bg_xys) # 对non-bg img,用KNN找最近的bg pixel
+ bg_fg_xys = fg_xys[indices[:, 0]]
+ bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
+ else:
+ raise NotImplementedError # deperated
+
+ return bg_img
+
+def inpaint_torso_job(gt_img, segmap):
+ bg_part = (segmap[0]).astype(bool)
+ head_part = (segmap[1] + segmap[3] + segmap[5]).astype(bool)
+ neck_part = (segmap[2]).astype(bool)
+ torso_part = (segmap[4]).astype(bool)
+ img = gt_img.copy()
+ img[head_part] = 0
+
+ # torso part "vertical" in-painting...
+ L = 8 + 1
+ torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
+ # lexsort: sort 2D coords first by y then by x,
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
+ inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
+ torso_coords = torso_coords[inds]
+ # choose the top pixel for each column
+ u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
+ top_torso_coords = torso_coords[uid] # [m, 2]
+ # only keep top-is-head pixels
+ top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
+ mask = head_part[tuple(top_torso_coords_up.T)]
+ if mask.any():
+ top_torso_coords = top_torso_coords[mask]
+ # get the color
+ top_torso_colors = gt_img[tuple(top_torso_coords.T)] # [m, 3]
+ # construct inpaint coords (vertically up, or minus in x)
+ inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
+ inpaint_torso_coords += inpaint_offsets
+ inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
+ inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
+ inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
+ # set color
+ img[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
+ inpaint_torso_mask = np.zeros_like(img[..., 0]).astype(bool)
+ inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
+ else:
+ inpaint_torso_mask = None
+
+ # neck part "vertical" in-painting...
+ push_down = 4
+ L = 48 + push_down + 1
+ neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
+ neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
+ # lexsort: sort 2D coords first by y then by x,
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
+ inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
+ neck_coords = neck_coords[inds]
+ # choose the top pixel for each column
+ u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
+ top_neck_coords = neck_coords[uid] # [m, 2]
+ # only keep top-is-head pixels
+ top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
+ mask = head_part[tuple(top_neck_coords_up.T)]
+ top_neck_coords = top_neck_coords[mask]
+ # push these top down for 4 pixels to make the neck inpainting more natural...
+ offset_down = np.minimum(ucnt[mask] - 1, push_down)
+ top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
+ # get the color
+ top_neck_colors = gt_img[tuple(top_neck_coords.T)] # [m, 3]
+ # construct inpaint coords (vertically up, or minus in x)
+ inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
+ inpaint_neck_coords += inpaint_offsets
+ inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
+ inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
+ inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
+ # set color
+ img[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
+ # apply blurring to the inpaint area to avoid vertical-line artifects...
+ inpaint_mask = np.zeros_like(img[..., 0]).astype(bool)
+ inpaint_mask[tuple(inpaint_neck_coords.T)] = True
+
+ blur_img = img.copy()
+ blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
+ img[inpaint_mask] = blur_img[inpaint_mask]
+
+ # set mask
+ torso_img_mask = (neck_part | torso_part | inpaint_mask)
+ torso_with_bg_img_mask = (bg_part | neck_part | torso_part | inpaint_mask)
+ if inpaint_torso_mask is not None:
+ torso_img_mask = torso_img_mask | inpaint_torso_mask
+ torso_with_bg_img_mask = torso_with_bg_img_mask | inpaint_torso_mask
+
+ torso_img = img.copy()
+ torso_img[~torso_img_mask] = 0
+ torso_with_bg_img = img.copy()
+ torso_img[~torso_with_bg_img_mask] = 0
+
+ return torso_img, torso_img_mask, torso_with_bg_img, torso_with_bg_img_mask
+
+
+def extract_segment_job(video_name, nerf=False, idx=None, total=None, background_method='knn', device="cpu", total_gpus=0, mix_bg=True):
+ global segmenter
+ global seg_model
+ del segmenter
+ del seg_model
+ seg_model = MediapipeSegmenter()
+ segmenter = vision.ImageSegmenter.create_from_options(seg_model.video_options)
+ try:
+ if "cuda" in device:
+ # determine which cuda index from subprocess id
+ pname = multiprocessing.current_process().name
+ pid = int(pname.rsplit("-", 1)[-1]) - 1
+ cuda_id = pid % total_gpus
+ device = f"cuda:{cuda_id}"
+
+ if nerf: # single video
+ raw_img_dir = video_name.replace(".mp4", "/gt_imgs/").replace("/raw/","/processed/")
+ else: # whole dataset
+ raw_img_dir = video_name.replace(".mp4", "").replace("/video/", "/gt_imgs/")
+ if not os.path.exists(raw_img_dir):
+ extract_img_job(video_name, raw_img_dir) # use ffmpeg to split video into imgs
+
+ img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
+
+ img_lst = []
+
+ for img_name in img_names:
+ img = cv2.imread(img_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_lst.append(img)
+
+ segmap_mask_lst, segmap_image_lst = seg_model._cal_seg_map_for_video(img_lst, segmenter=segmenter, return_onehot_mask=True, return_segmap_image=True)
+ del segmap_image_lst
+ # for i in range(len(img_lst)):
+ for i in tqdm.trange(len(img_lst), desc='generating segment images using segmaps...'):
+ img_name = img_names[i]
+ segmap = segmap_mask_lst[i]
+ img = img_lst[i]
+ out_img_name = img_name.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png") # 存成jpg的话,pixel value会有误差
+ try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
+ except: pass
+ encoded_segmap = encode_segmap_mask_to_image(segmap)
+ save_rgb_image_to_path(encoded_segmap, out_img_name)
+
+ for mode in ['head', 'torso', 'person', 'bg']:
+ out_img, mask = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
+ img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
+ mask = mask[0][..., None]
+ img_alpha[~mask] = 0
+ out_img_name = img_name.replace("/gt_imgs/", f"/{mode}_imgs/").replace(".jpg", ".png")
+ save_rgb_alpha_image_to_path(out_img, img_alpha, out_img_name)
+
+ inpaint_torso_img, inpaint_torso_img_mask, inpaint_torso_with_bg_img, inpaint_torso_with_bg_img_mask = inpaint_torso_job(img, segmap)
+ img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
+ img_alpha[~inpaint_torso_img_mask[..., None]] = 0
+ out_img_name = img_name.replace("/gt_imgs/", f"/inpaint_torso_imgs/").replace(".jpg", ".png")
+ save_rgb_alpha_image_to_path(inpaint_torso_img, img_alpha, out_img_name)
+
+ bg_prefix_name = f"bg{BG_NAME_MAP[background_method]}"
+ bg_img = extract_background(img_lst, segmap_mask_lst, method=background_method, device=device, mix_bg=mix_bg)
+ if nerf:
+ out_img_name = video_name.replace("/raw/", "/processed/").replace(".mp4", f"/{bg_prefix_name}.jpg")
+ else:
+ out_img_name = video_name.replace("/video/", f"/{bg_prefix_name}_img/").replace(".mp4", ".jpg")
+ save_rgb_image_to_path(bg_img, out_img_name)
+
+ com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
+ for i, img_name in enumerate(img_names):
+ com_img = img_lst[i].copy()
+ segmap = segmap_mask_lst[i]
+ bg_part = segmap[0].astype(bool)[..., None].repeat(3,axis=-1)
+ com_img[bg_part] = bg_img[bg_part]
+ out_img_name = img_name.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
+ save_rgb_image_to_path(com_img, out_img_name)
+ return 0
+ except Exception as e:
+ print(str(type(e)), e)
+ traceback.print_exc(e)
+ return 1
+
+# def check_bg_img_job_finished(raw_img_dir, bg_name, com_dir):
+# img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
+# com_names = glob.glob(os.path.join(com_dir, "*.jpg"))
+# return len(img_names) == len(com_names) and os.path.exists(bg_name)
+
+# extract background and combined image
+# need pre-processed "gt_imgs" and "segmaps"
+def extract_bg_img_job(video_name, nerf=False, idx=None, total=None, background_method='knn', device="cpu", total_gpus=0, mix_bg=True):
+ try:
+ bg_prefix_name = f"bg{BG_NAME_MAP[background_method]}"
+ com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
+
+ if "cuda" in device:
+ # determine which cuda index from subprocess id
+ pname = multiprocessing.current_process().name
+ pid = int(pname.rsplit("-", 1)[-1]) - 1
+ cuda_id = pid % total_gpus
+ device = f"cuda:{cuda_id}"
+
+ if nerf: # single video
+ raw_img_dir = video_name.replace(".mp4", "/gt_imgs/").replace("/raw/","/processed/")
+ else: # whole dataset
+ raw_img_dir = video_name.replace(".mp4", "").replace("/video/", "/gt_imgs/")
+ if nerf:
+ bg_name = video_name.replace("/raw/", "/processed/").replace(".mp4", f"/{bg_prefix_name}.jpg")
+ else:
+ bg_name = video_name.replace("/video/", f"/{bg_prefix_name}_img/").replace(".mp4", ".jpg")
+ # com_dir = raw_img_dir.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
+ # if check_bg_img_job_finished(raw_img_dir=raw_img_dir, bg_name=bg_name, com_dir=com_dir):
+ # print(f"Already finished, skip {raw_img_dir} ")
+ # return 0
+
+ img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
+ img_lst = []
+ for img_name in img_names:
+ img = cv2.imread(img_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_lst.append(img)
+
+ segmap_mask_lst = []
+ for img_name in img_names:
+ segmap_img_name = img_name.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png")
+ segmap_img = load_rgb_image_to_path(segmap_img_name)
+
+ segmap_mask = decode_segmap_mask_from_image(segmap_img)
+ segmap_mask_lst.append(segmap_mask)
+
+ bg_img = extract_background(img_lst, segmap_mask_lst, method=background_method, device=device, mix_bg=mix_bg)
+ save_rgb_image_to_path(bg_img, bg_name)
+
+ for i, img_name in enumerate(img_names):
+ com_img = img_lst[i].copy()
+ segmap = segmap_mask_lst[i]
+ bg_part = segmap[0].astype(bool)[..., None].repeat(3, axis=-1)
+ com_img[bg_part] = bg_img[bg_part]
+ com_name = img_name.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
+ save_rgb_image_to_path(com_img, com_name)
+ return 0
+
+ except Exception as e:
+ print(str(type(e)), e)
+ traceback.print_exc(e)
+ return 1
+
+def out_exist_job(vid_name, background_method='knn', only_bg_img=False):
+ com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
+ img_dir = vid_name.replace("/video/", "/gt_imgs/").replace(".mp4", "")
+ out_dir1 = img_dir.replace("/gt_imgs/", "/head_imgs/")
+ out_dir2 = img_dir.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
+
+ if not only_bg_img:
+ if os.path.exists(img_dir) and os.path.exists(out_dir1) and os.path.exists(out_dir1) and os.path.exists(out_dir2) :
+ num_frames = len(os.listdir(img_dir))
+ if len(os.listdir(out_dir1)) == num_frames and len(os.listdir(out_dir2)) == num_frames:
+ return None
+ else:
+ return vid_name
+ else:
+ return vid_name
+ else:
+ if os.path.exists(img_dir) and os.path.exists(out_dir2):
+ num_frames = len(os.listdir(img_dir))
+ if len(os.listdir(out_dir2)) == num_frames:
+ return None
+ else:
+ return vid_name
+ else:
+ return vid_name
+
+def get_todo_vid_names(vid_names, background_method='knn', only_bg_img=False):
+ if len(vid_names) == 1: # nerf
+ return vid_names
+ todo_vid_names = []
+ fn_args = [(vid_name, background_method, only_bg_img) for vid_name in vid_names]
+ for i, res in multiprocess_run_tqdm(out_exist_job, fn_args, num_workers=16, desc="checking todo videos..."):
+ if res is not None:
+ todo_vid_names.append(res)
+ return todo_vid_names
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
+ parser.add_argument("--ds_name", default='CelebV-HQ')
+ parser.add_argument("--num_workers", default=48, type=int)
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--reset", action='store_true')
+ parser.add_argument("--load_names", action="store_true")
+ parser.add_argument("--background_method", choices=['knn', 'mat', 'ddnm', 'lama'], type=str, default='knn')
+ parser.add_argument("--total_gpus", default=0, type=int) # zero gpus means utilizing cpu
+ parser.add_argument("--only_bg_img", action="store_true")
+ parser.add_argument("--no_mix_bg", action="store_true")
+
+ args = parser.parse_args()
+ vid_dir = args.vid_dir
+ ds_name = args.ds_name
+ load_names = args.load_names
+ background_method = args.background_method
+ total_gpus = args.total_gpus
+ only_bg_img = args.only_bg_img
+ mix_bg = not args.no_mix_bg
+
+ devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
+ for d in devices[:total_gpus]:
+ os.system(f'pkill -f "voidgpu{d}"')
+
+ if ds_name.lower() == 'nerf': # 处理单个视频
+ vid_names = [vid_dir]
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_lms.npy") for video_name in vid_names]
+ else: # 处理整个数据集
+ if ds_name in ['lrs3_trainval']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ elif ds_name in ["RAVDESS", 'VFHQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
+ else:
+ raise NotImplementedError()
+
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
+ if os.path.exists(vid_names_path) and load_names:
+ print(f"loading vid names from {vid_names_path}")
+ vid_names = load_file(vid_names_path)
+ else:
+ vid_names = multiprocess_glob(vid_name_pattern)
+ vid_names = sorted(vid_names)
+ print(f"saving vid names to {vid_names_path}")
+ save_file(vid_names_path, vid_names)
+
+ vid_names = sorted(vid_names)
+ random.seed(args.seed)
+ random.shuffle(vid_names)
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(vid_names) // total_process
+ if process_id == total_process:
+ vid_names = vid_names[process_id * num_samples_per_process : ]
+ else:
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ if not args.reset:
+ vid_names = get_todo_vid_names(vid_names, background_method, only_bg_img)
+ print(f"todo videos number: {len(vid_names)}")
+ # exit()
+
+ device = "cuda" if total_gpus > 0 else "cpu"
+ if only_bg_img:
+ extract_job = extract_bg_img_job
+ fn_args = [(vid_name,ds_name=='nerf',i,len(vid_names), background_method, device, total_gpus, mix_bg) for i, vid_name in enumerate(vid_names)]
+ else:
+ extract_job = extract_segment_job
+ fn_args = [(vid_name,ds_name=='nerf',i,len(vid_names), background_method, device, total_gpus, mix_bg) for i, vid_name in enumerate(vid_names)]
+
+ for vid_name in multiprocess_run_tqdm(extract_job, fn_args, desc=f"Root process {args.process_id}: segment images", num_workers=args.num_workers):
+ pass
\ No newline at end of file
diff --git a/data_gen/utils/process_video/fit_3dmm_landmark.py b/data_gen/utils/process_video/fit_3dmm_landmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..2622860f66989758fafc7f244b89b4f6da5f43f1
--- /dev/null
+++ b/data_gen/utils/process_video/fit_3dmm_landmark.py
@@ -0,0 +1,565 @@
+# This is a script for efficienct 3DMM coefficient extraction.
+# It could reconstruct accurate 3D face in real-time.
+# It is built upon BFM 2009 model and mediapipe landmark extractor.
+# It is authored by ZhenhuiYe (zhenhuiye@zju.edu.cn), free to contact him for any suggestion on improvement!
+
+from numpy.core.numeric import require
+from numpy.lib.function_base import quantile
+import torch
+import torch.nn.functional as F
+import copy
+import numpy as np
+
+import random
+import pickle
+import os
+import sys
+import cv2
+import argparse
+import tqdm
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker, read_video_to_frames
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+from deep_3drecon.secc_renderer import SECC_Renderer
+from utils.commons.os_utils import multiprocess_glob
+
+
+face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
+ camera_distance=10, focal=1015, keypoint_mode='mediapipe')
+face_model.to(torch.device("cuda:0"))
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+
+
+def draw_axes(img, pitch, yaw, roll, tx, ty, size=50):
+ # yaw = -yaw
+ pitch = - pitch
+ roll = - roll
+ rotation_matrix = cv2.Rodrigues(np.array([pitch, yaw, roll]))[0].astype(np.float64)
+ axes_points = np.array([
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0]
+ ], dtype=np.float64)
+ axes_points = rotation_matrix @ axes_points
+ axes_points = (axes_points[:2, :] * size).astype(int)
+ axes_points[0, :] = axes_points[0, :] + tx
+ axes_points[1, :] = axes_points[1, :] + ty
+
+ new_img = img.copy()
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 0].ravel()), (255, 0, 0), 3)
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 1].ravel()), (0, 255, 0), 3)
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 2].ravel()), (0, 0, 255), 3)
+ return new_img
+
+def save_file(name, content):
+ with open(name, "wb") as f:
+ pickle.dump(content, f)
+
+def load_file(name):
+ with open(name, "rb") as f:
+ content = pickle.load(f)
+ return content
+
+def cal_lap_loss(in_tensor):
+ # [T, 68, 2]
+ t = in_tensor.shape[0]
+ in_tensor = in_tensor.reshape([t, -1]).permute(1,0).unsqueeze(1) # [c, 1, t]
+ in_tensor = torch.cat([in_tensor[:, :, 0:1], in_tensor, in_tensor[:, :, -1:]], dim=-1)
+ lap_kernel = torch.Tensor((-0.5, 1.0, -0.5)).reshape([1,1,3]).float().to(in_tensor.device) # [1, 1, kw]
+ loss_lap = 0
+
+ out_tensor = F.conv1d(in_tensor, lap_kernel)
+ loss_lap += torch.mean(out_tensor**2)
+ return loss_lap
+
+def cal_vel_loss(ldm):
+ # [B, 68, 2]
+ vel = ldm[1:] - ldm[:-1]
+ return torch.mean(torch.abs(vel))
+
+def cal_lan_loss(proj_lan, gt_lan):
+ # [B, 68, 2]
+ loss = (proj_lan - gt_lan)** 2
+ # use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
+ weights = torch.zeros_like(loss)
+ weights = torch.ones_like(loss)
+ weights[:, 36:48, :] = 3 # eye 12 points
+ weights[:, -8:, :] = 3 # inner lip 8 points
+ weights[:, 28:31, :] = 3 # nose 3 points
+ loss = loss * weights
+ return torch.mean(loss)
+
+def cal_lan_loss_mp(proj_lan, gt_lan, mean:bool=True):
+ # [B, 68, 2]
+ loss = (proj_lan - gt_lan).pow(2)
+ # loss = (proj_lan - gt_lan).abs()
+ unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
+ upper_eye = [161,160,159,158,157] + [388,387,386,385,384]
+ eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
+ inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
+ outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
+ weights = torch.ones_like(loss)
+ weights[:, eye] = 3
+ weights[:, upper_eye] = 20
+ weights[:, inner_lip] = 5
+ weights[:, outer_lip] = 5
+ weights[:, unmatch_mask] = 0
+ loss = loss * weights
+ if mean:
+ loss = torch.mean(loss)
+ return loss
+
+def cal_acceleration_loss(trans):
+ vel = trans[1:] - trans[:-1]
+ acc = vel[1:] - vel[:-1]
+ return torch.mean(torch.abs(acc))
+
+def cal_acceleration_ldm_loss(ldm):
+ # [B, 68, 2]
+ vel = ldm[1:] - ldm[:-1]
+ acc = vel[1:] - vel[:-1]
+ lip_weight = 0.25 # we dont want smooth the lip too much
+ acc[48:68] *= lip_weight
+ return torch.mean(torch.abs(acc))
+
+def set_requires_grad(tensor_list):
+ for tensor in tensor_list:
+ tensor.requires_grad = True
+
+@torch.enable_grad()
+def fit_3dmm_for_a_video(
+ video_name,
+ nerf=False, # use the file name convention for GeneFace++
+ id_mode='global',
+ debug=False,
+ keypoint_mode='mediapipe',
+ large_yaw_threshold=9999999.9,
+ save=True
+) -> bool: # True: good, False: bad
+ assert video_name.endswith(".mp4"), "this function only support video as input"
+ if id_mode == 'global':
+ LAMBDA_REG_ID = 0.2
+ LAMBDA_REG_EXP = 0.6
+ LAMBDA_REG_LAP = 1.0
+ LAMBDA_REG_VEL_ID = 0.0 # laplcaian is all you need for temporal consistency
+ LAMBDA_REG_VEL_EXP = 0.0 # laplcaian is all you need for temporal consistency
+ else:
+ LAMBDA_REG_ID = 0.3
+ LAMBDA_REG_EXP = 0.05
+ LAMBDA_REG_LAP = 1.0
+ LAMBDA_REG_VEL_ID = 0.0 # laplcaian is all you need for temporal consistency
+ LAMBDA_REG_VEL_EXP = 0.0 # laplcaian is all you need for temporal consistency
+
+ frames = read_video_to_frames(video_name) # [T, H, W, 3]
+ img_h, img_w = frames.shape[1], frames.shape[2]
+ assert img_h == img_w
+ num_frames = len(frames)
+
+ if nerf: # single video
+ lm_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
+ else:
+ lm_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy")
+
+ if os.path.exists(lm_name):
+ lms = np.load(lm_name)
+ else:
+ print(f"lms_2d file not found, try to extract it from video... {lm_name}")
+ try:
+ landmarker = MediapipeLandmarker()
+ img_lm478, vid_lm478 = landmarker.extract_lm478_from_frames(frames, anti_smooth_factor=20)
+ lms = landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
+ except Exception as e:
+ print(e)
+ return False
+ if lms is None:
+ print(f"get None lms_2d, please check whether each frame has one head, exiting... {lm_name}")
+ return False
+ lms = lms[:, :468, :]
+ lms = torch.FloatTensor(lms).cuda()
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
+
+ if keypoint_mode == 'mediapipe':
+ # default
+ cal_lan_loss_fn = cal_lan_loss_mp
+ if nerf: # single video
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/coeff_fit_mp.npy")
+ else:
+ out_name = video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4", "_coeff_fit_mp.npy")
+ else:
+ # lm68 is less accurate than mp
+ cal_lan_loss_fn = cal_lan_loss
+ if nerf: # single video
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "_coeff_fit_lm68.npy")
+ else:
+ out_name = video_name.replace("/video/", "/coeff_fit_lm68/").replace(".mp4", "_coeff_fit_lm68.npy")
+ try:
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+ except:
+ pass
+
+ id_dim, exp_dim = 80, 64
+ sel_ids = np.arange(0, num_frames, 40)
+
+ h = w = face_model.center * 2
+ img_scale_factor = img_h / h
+ lms /= img_scale_factor # rescale lms into [0,224]
+
+ if id_mode == 'global':
+ # default choice by GeneFace++ and later works
+ id_para = lms.new_zeros((1, id_dim), requires_grad=True)
+ elif id_mode == 'finegrained':
+ # legacy choice by GeneFace1 (ICLR 2023)
+ id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True)
+ else: raise NotImplementedError(f"id mode {id_mode} not supported! we only support global or finegrained.")
+ exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
+ euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
+ trans = lms.new_zeros((num_frames, 3), requires_grad=True)
+
+ set_requires_grad([id_para, exp_para, euler_angle, trans])
+
+ optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
+ optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
+
+ # 其他参数初始化,先训练euler和trans
+ for _ in range(200):
+ if id_mode == 'global':
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans)
+ else:
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para, exp_para, euler_angle, trans)
+ loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
+ loss = loss_lan
+ optimizer_frame.zero_grad()
+ loss.backward()
+ optimizer_frame.step()
+
+ # print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
+
+ for param_group in optimizer_frame.param_groups:
+ param_group['lr'] = 0.1
+
+ # "jointly roughly training id exp euler trans"
+ for _ in range(200):
+ ret = {}
+ if id_mode == 'global':
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans, ret)
+ else:
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para, exp_para, euler_angle, trans, ret)
+ loss_lan = cal_lan_loss_fn(
+ proj_geo[:, :, :2], lms.detach())
+ # loss_lap = cal_lap_loss(proj_geo)
+ # laplacian对euler影响不大,但是对trans的提升很大
+ loss_lap = cal_lap_loss(id_para) + cal_lap_loss(exp_para) + cal_lap_loss(euler_angle) * 0.3 + cal_lap_loss(trans) * 0.3
+
+ loss_regid = torch.mean(id_para*id_para) # 正则化
+ loss_regexp = torch.mean(exp_para * exp_para)
+
+ loss_vel_id = cal_vel_loss(id_para)
+ loss_vel_exp = cal_vel_loss(exp_para)
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP + loss_lap * LAMBDA_REG_LAP
+ optimizer_idexp.zero_grad()
+ optimizer_frame.zero_grad()
+ loss.backward()
+ optimizer_idexp.step()
+ optimizer_frame.step()
+
+ # print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
+ # print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
+
+ # start fine training, intialize from the roughly trained results
+ if id_mode == 'global':
+ id_para_ = lms.new_zeros((1, id_dim), requires_grad=False)
+ else:
+ id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
+ id_para_.data = id_para.data.clone()
+ id_para = id_para_
+ exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
+ exp_para_.data = exp_para.data.clone()
+ exp_para = exp_para_
+ euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
+ euler_angle_.data = euler_angle.data.clone()
+ euler_angle = euler_angle_
+ trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
+ trans_.data = trans.data.clone()
+ trans = trans_
+
+ batch_size = 50
+ # "fine fitting the 3DMM in batches"
+ for i in range(int((num_frames-1)/batch_size+1)):
+ if (i+1)*batch_size > num_frames:
+ start_n = num_frames-batch_size
+ sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
+ else:
+ start_n = i*batch_size
+ sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
+ sel_lms = lms[sel_ids]
+
+ if id_mode == 'global':
+ sel_id_para = id_para.expand((sel_ids.shape[0], id_dim))
+ else:
+ sel_id_para = id_para.new_zeros((batch_size, id_dim), requires_grad=True)
+ sel_id_para.data = id_para[sel_ids].clone()
+ sel_exp_para = exp_para.new_zeros(
+ (batch_size, exp_dim), requires_grad=True)
+ sel_exp_para.data = exp_para[sel_ids].clone()
+ sel_euler_angle = euler_angle.new_zeros(
+ (batch_size, 3), requires_grad=True)
+ sel_euler_angle.data = euler_angle[sel_ids].clone()
+ sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
+ sel_trans.data = trans[sel_ids].clone()
+
+ if id_mode == 'global':
+ set_requires_grad([sel_exp_para, sel_euler_angle, sel_trans])
+ optimizer_cur_batch = torch.optim.Adam(
+ [sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
+ else:
+ set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
+ optimizer_cur_batch = torch.optim.Adam(
+ [sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
+
+ for j in range(50):
+ ret = {}
+ proj_geo = face_model.compute_for_landmark_fit(
+ sel_id_para, sel_exp_para, sel_euler_angle, sel_trans, ret)
+ loss_lan = cal_lan_loss_fn(
+ proj_geo[:, :, :2], lms[sel_ids].detach())
+
+ # loss_lap = cal_lap_loss(proj_geo)
+ loss_lap = cal_lap_loss(sel_id_para) + cal_lap_loss(sel_exp_para) + cal_lap_loss(sel_euler_angle) * 0.3 + cal_lap_loss(sel_trans) * 0.3
+ loss_vel_id = cal_vel_loss(sel_id_para)
+ loss_vel_exp = cal_vel_loss(sel_exp_para)
+ log_dict = {
+ 'loss_vel_id': loss_vel_id,
+ 'loss_vel_exp': loss_vel_exp,
+ 'loss_vel_euler': cal_vel_loss(sel_euler_angle),
+ 'loss_vel_trans': cal_vel_loss(sel_trans),
+ }
+ loss_regid = torch.mean(sel_id_para*sel_id_para) # 正则化
+ loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_lap * LAMBDA_REG_LAP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP
+
+ optimizer_cur_batch.zero_grad()
+ loss.backward()
+ optimizer_cur_batch.step()
+
+ if debug:
+ print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},loss_lap_ldm:{loss_lap.item():.4f}")
+ print("|--------" + ', '.join([f"{k}: {v:.4f}" for k,v in log_dict.items()]))
+ if id_mode != 'global':
+ id_para[sel_ids].data = sel_id_para.data.clone()
+ exp_para[sel_ids].data = sel_exp_para.data.clone()
+ euler_angle[sel_ids].data = sel_euler_angle.data.clone()
+ trans[sel_ids].data = sel_trans.data.clone()
+
+ coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
+ 'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
+
+ # filter data by side-view pose
+ # bad_yaw = False
+ # yaws = [] # not so accurate
+ # for index in range(coeff_dict["trans"].shape[0]):
+ # yaw = coeff_dict["euler"][index][1]
+ # yaw = np.abs(yaw)
+ # yaws.append(yaw)
+ # if yaw > large_yaw_threshold:
+ # bad_yaw = True
+
+ if debug:
+ import imageio
+ from utils.visualization.vis_cam3d.camera_pose_visualizer import CameraPoseVisualizer
+ from data_util.face3d_helper import Face3DHelper
+ from data_gen.utils.process_video.extract_blink import get_eye_area_percent
+ face3d_helper = Face3DHelper('deep_3drecon/BFM', keypoint_mode='mediapipe')
+
+ t = coeff_dict['exp'].shape[0]
+ if len(coeff_dict['id']) == 1:
+ coeff_dict['id'] = np.repeat(coeff_dict['id'], t, axis=0)
+ idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d_np(coeff_dict['id'], coeff_dict['exp']).reshape([t, -1])
+ cano_lm3d = idexp_lm3d / 10 + face3d_helper.key_mean_shape.squeeze().reshape([1, -1]).cpu().numpy()
+ cano_lm3d = cano_lm3d.reshape([t, -1, 3])
+ WH = 512
+ cano_lm3d = (cano_lm3d * WH/2 + WH/2).astype(int)
+
+ with torch.no_grad():
+ rot = ParametricFaceModel.compute_rotation(euler_angle)
+ extrinsic = torch.zeros([rot.shape[0], 4, 4]).to(rot.device)
+ extrinsic[:, :3,:3] = rot
+ extrinsic[:, :3, 3] = trans # / 10
+ extrinsic[:, 3, 3] = 1
+ extrinsic = extrinsic.cpu().numpy()
+
+ xy_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xy')
+ xz_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xz')
+
+ if nerf:
+ debug_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/debug_fit_3dmm.mp4")
+ else:
+ debug_name = video_name.replace("/video/", "/coeff_fit_debug/").replace(".mp4", "_debug.mp4")
+ try:
+ os.makedirs(os.path.dirname(debug_name), exist_ok=True)
+ except: pass
+ writer = imageio.get_writer(debug_name, fps=25)
+ if id_mode == 'global':
+ id_para = id_para.repeat([exp_para.shape[0], 1])
+ proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
+ lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
+ lm68s = lm68s * img_scale_factor
+ lms = lms * img_scale_factor
+ lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
+ lm68s = lm68s.astype(int)
+ for i in tqdm.trange(min(250, len(frames)), desc=f'rendering debug video to {debug_name}..'):
+ xy_cam3d_img = xy_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
+ xy_cam3d_img = cv2.resize(xy_cam3d_img, (512,512))
+ xz_cam3d_img = xz_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
+ xz_cam3d_img = cv2.resize(xz_cam3d_img, (512,512))
+
+ img = copy.deepcopy(frames[i])
+ img2 = copy.deepcopy(frames[i])
+
+ img = draw_axes(img, euler_angle[i,0].item(), euler_angle[i,1].item(), euler_angle[i,2].item(), lm68s[i][4][0].item(), lm68s[i, 4][1].item(), size=50)
+
+ gt_lm_color = (255, 0, 0)
+
+ for lm in lm68s[i]:
+ img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1) # blue
+ for gt_lm in lms[i]:
+ img2 = cv2.circle(img2, gt_lm.cpu().numpy().astype(int), 2, gt_lm_color, thickness=1)
+
+ cano_lm3d_img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
+ for j in range(len(cano_lm3d[i])):
+ x, y, _ = cano_lm3d[i, j]
+ color = (255,0,0)
+ cano_lm3d_img = cv2.circle(cano_lm3d_img, center=(x,y), radius=3, color=color, thickness=-1)
+ cano_lm3d_img = cv2.flip(cano_lm3d_img, 0)
+
+ _, secc_img = secc_renderer(id_para[0:1], exp_para[i:i+1], euler_angle[i:i+1]*0, trans[i:i+1]*0)
+ secc_img = (secc_img +1)*127.5
+ secc_img = F.interpolate(secc_img, size=(img_h, img_w))
+ secc_img = secc_img.permute(0, 2,3,1).int().cpu().numpy()[0]
+ out_img1 = np.concatenate([img, img2, secc_img], axis=1).astype(np.uint8)
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ out_img2 = np.concatenate([xy_cam3d_img, xz_cam3d_img, cano_lm3d_img], axis=1).astype(np.uint8)
+ out_img = np.concatenate([out_img1, out_img2], axis=0)
+ writer.append_data(out_img)
+ writer.close()
+
+ # if bad_yaw:
+ # print(f"Skip {video_name} due to TOO LARGE YAW")
+ # return False
+
+ if save:
+ np.save(out_name, coeff_dict, allow_pickle=True)
+ return coeff_dict
+
+def out_exist_job(vid_name):
+ out_name = vid_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy")
+ lms_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
+ if os.path.exists(out_name) or not os.path.exists(lms_name):
+ return None
+ else:
+ return vid_name
+
+def get_todo_vid_names(vid_names):
+ if len(vid_names) == 1: # single video, nerf
+ return vid_names
+ todo_vid_names = []
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
+ if res is not None:
+ todo_vid_names.append(res)
+ return todo_vid_names
+
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm
+ parser = argparse.ArgumentParser()
+ # parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
+ parser.add_argument("--vid_dir", default='data/raw/videos/May_10s.mp4')
+ parser.add_argument("--ds_name", default='nerf') # 'nerf' | 'CelebV-HQ' | 'TH1KH_512' | etc
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--id_mode", default='global', type=str) # global | finegrained
+ parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
+ parser.add_argument("--large_yaw_threshold", default=9999999.9, type=float) # could be 0.7
+ parser.add_argument("--debug", action='store_true')
+ parser.add_argument("--reset", action='store_true')
+ parser.add_argument("--load_names", action="store_true")
+
+ args = parser.parse_args()
+ vid_dir = args.vid_dir
+ ds_name = args.ds_name
+ load_names = args.load_names
+
+ print(f"args {args}")
+
+ if ds_name.lower() == 'nerf': # 处理单个视频
+ vid_names = [vid_dir]
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
+ else: # 处理整个数据集
+ if ds_name in ['lrs3_trainval']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ elif ds_name in ["RAVDESS", 'VFHQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
+ else:
+ raise NotImplementedError()
+
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
+ if os.path.exists(vid_names_path) and load_names:
+ print(f"loading vid names from {vid_names_path}")
+ vid_names = load_file(vid_names_path)
+ else:
+ vid_names = multiprocess_glob(vid_name_pattern)
+ vid_names = sorted(vid_names)
+ print(f"saving vid names to {vid_names_path}")
+ save_file(vid_names_path, vid_names)
+ out_names = [video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
+
+ print(vid_names[:10])
+ random.seed(args.seed)
+ random.shuffle(vid_names)
+
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
+ camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
+ face_model.to(torch.device("cuda:0"))
+ secc_renderer = SECC_Renderer(512)
+ secc_renderer.to("cuda:0")
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(vid_names) // total_process
+ if process_id == total_process:
+ vid_names = vid_names[process_id * num_samples_per_process : ]
+ else:
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ if not args.reset:
+ vid_names = get_todo_vid_names(vid_names)
+
+ failed_img_names = []
+ for i in tqdm.trange(len(vid_names), desc=f"process {process_id}: fitting 3dmm ..."):
+ img_name = vid_names[i]
+ try:
+ is_person_specific_data = ds_name=='nerf'
+ success = fit_3dmm_for_a_video(img_name, is_person_specific_data, args.id_mode, args.debug, large_yaw_threshold=args.large_yaw_threshold)
+ if not success:
+ failed_img_names.append(img_name)
+ except Exception as e:
+ print(img_name, e)
+ failed_img_names.append(img_name)
+ print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {len(failed_img_names)} / {i + 1} = {len(failed_img_names) / (i + 1):.4f}")
+ sys.stdout.flush()
+ print(f"all failed image names: {failed_img_names}")
+ print(f"All finished!")
\ No newline at end of file
diff --git a/data_gen/utils/process_video/inpaint_torso_imgs.py b/data_gen/utils/process_video/inpaint_torso_imgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c938a6f79e7b796cc6f321e332eb7840244b4cf9
--- /dev/null
+++ b/data_gen/utils/process_video/inpaint_torso_imgs.py
@@ -0,0 +1,193 @@
+import cv2
+import os
+import numpy as np
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from scipy.ndimage import binary_erosion, binary_dilation
+
+from tasks.eg3ds.loss_utils.segment_loss.mp_segmenter import MediapipeSegmenter
+seg_model = MediapipeSegmenter()
+
+def inpaint_torso_job(video_name, idx=None, total=None):
+ raw_img_dir = video_name.replace(".mp4", "").replace("/video/","/gt_imgs/")
+ img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
+
+ for image_path in tqdm.tqdm(img_names):
+ # read ori image
+ ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
+ segmap = seg_model._cal_seg_map(cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB))
+ head_part = (segmap[1] + segmap[3] + segmap[5]).astype(np.bool)
+ torso_part = (segmap[4]).astype(np.bool)
+ neck_part = (segmap[2]).astype(np.bool)
+ bg_part = segmap[0].astype(np.bool)
+ head_image = cv2.imread(image_path.replace("/gt_imgs/", "/head_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
+ torso_image = cv2.imread(image_path.replace("/gt_imgs/", "/torso_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
+ bg_image = cv2.imread(image_path.replace("/gt_imgs/", "/bg_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
+
+ # head_part = (head_image[...,0] != 0) & (head_image[...,1] != 0) & (head_image[...,2] != 0)
+ # torso_part = (torso_image[...,0] != 0) & (torso_image[...,1] != 0) & (torso_image[...,2] != 0)
+ # bg_part = (bg_image[...,0] != 0) & (bg_image[...,1] != 0) & (bg_image[...,2] != 0)
+
+ # get gt image
+ gt_image = ori_image.copy()
+ gt_image[bg_part] = bg_image[bg_part]
+ cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)
+
+ # get torso image
+ torso_image = gt_image.copy() # rgb
+ torso_image[head_part] = 0
+ torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha
+
+ # torso part "vertical" in-painting...
+ L = 8 + 1
+ torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
+ # lexsort: sort 2D coords first by y then by x,
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
+ inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
+ torso_coords = torso_coords[inds]
+ # choose the top pixel for each column
+ u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
+ top_torso_coords = torso_coords[uid] # [m, 2]
+ # only keep top-is-head pixels
+ top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
+ mask = head_part[tuple(top_torso_coords_up.T)]
+ if mask.any():
+ top_torso_coords = top_torso_coords[mask]
+ # get the color
+ top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
+ # construct inpaint coords (vertically up, or minus in x)
+ inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
+ inpaint_torso_coords += inpaint_offsets
+ inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
+ inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
+ inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
+ # set color
+ torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
+
+ inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
+ inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
+ else:
+ inpaint_torso_mask = None
+
+ # neck part "vertical" in-painting...
+ push_down = 4
+ L = 48 + push_down + 1
+
+ neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
+
+ neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
+ # lexsort: sort 2D coords first by y then by x,
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
+ inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
+ neck_coords = neck_coords[inds]
+ # choose the top pixel for each column
+ u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
+ top_neck_coords = neck_coords[uid] # [m, 2]
+ # only keep top-is-head pixels
+ top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
+ mask = head_part[tuple(top_neck_coords_up.T)]
+
+ top_neck_coords = top_neck_coords[mask]
+ # push these top down for 4 pixels to make the neck inpainting more natural...
+ offset_down = np.minimum(ucnt[mask] - 1, push_down)
+ top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
+ # get the color
+ top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
+ # construct inpaint coords (vertically up, or minus in x)
+ inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
+ inpaint_neck_coords += inpaint_offsets
+ inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
+ inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
+ inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
+ # set color
+ torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
+
+ # apply blurring to the inpaint area to avoid vertical-line artifects...
+ inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
+ inpaint_mask[tuple(inpaint_neck_coords.T)] = True
+
+ blur_img = torso_image.copy()
+ blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
+
+ torso_image[inpaint_mask] = blur_img[inpaint_mask]
+
+ # set mask
+ mask = (neck_part | torso_part | inpaint_mask)
+ if inpaint_torso_mask is not None:
+ mask = mask | inpaint_torso_mask
+ torso_image[~mask] = 0
+ torso_alpha[~mask] = 0
+
+ cv2.imwrite("0.png", np.concatenate([torso_image, torso_alpha], axis=-1))
+
+ print(f'[INFO] ===== extracted torso and gt images =====')
+
+
+def out_exist_job(vid_name):
+ out_dir1 = vid_name.replace("/video/", "/inpaint_torso_imgs/").replace(".mp4","")
+ out_dir2 = vid_name.replace("/video/", "/inpaint_torso_with_bg_imgs/").replace(".mp4","")
+ out_dir3 = vid_name.replace("/video/", "/torso_imgs/").replace(".mp4","")
+ out_dir4 = vid_name.replace("/video/", "/torso_with_bg_imgs/").replace(".mp4","")
+
+ if os.path.exists(out_dir1) and os.path.exists(out_dir1) and os.path.exists(out_dir2) and os.path.exists(out_dir3) and os.path.exists(out_dir4):
+ num_frames = len(os.listdir(out_dir1))
+ if len(os.listdir(out_dir1)) == num_frames and len(os.listdir(out_dir2)) == num_frames and len(os.listdir(out_dir3)) == num_frames and len(os.listdir(out_dir4)) == num_frames:
+ return None
+ else:
+ return vid_name
+ else:
+ return vid_name
+
+def get_todo_vid_names(vid_names):
+ todo_vid_names = []
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
+ if res is not None:
+ todo_vid_names.append(res)
+ return todo_vid_names
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
+ parser.add_argument("--ds_name", default='CelebV-HQ')
+ parser.add_argument("--num_workers", default=48, type=int)
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--reset", action='store_true')
+
+ inpaint_torso_job('/home/tiger/datasets/raw/CelebV-HQ/video/dgdEr-mXQT4_8.mp4')
+ # args = parser.parse_args()
+ # vid_dir = args.vid_dir
+ # ds_name = args.ds_name
+ # if ds_name in ['lrs3_trainval']:
+ # mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
+ # if ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ # vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
+ # elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
+ # vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ # vid_names = glob.glob(vid_name_pattern)
+ # vid_names = sorted(vid_names)
+ # random.seed(args.seed)
+ # random.shuffle(vid_names)
+
+ # process_id = args.process_id
+ # total_process = args.total_process
+ # if total_process > 1:
+ # assert process_id <= total_process -1
+ # num_samples_per_process = len(vid_names) // total_process
+ # if process_id == total_process:
+ # vid_names = vid_names[process_id * num_samples_per_process : ]
+ # else:
+ # vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ # if not args.reset:
+ # vid_names = get_todo_vid_names(vid_names)
+ # print(f"todo videos number: {len(vid_names)}")
+
+ # fn_args = [(vid_name,i,len(vid_names)) for i, vid_name in enumerate(vid_names)]
+ # for vid_name in multiprocess_run_tqdm(inpaint_torso_job ,fn_args, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
+ # pass
\ No newline at end of file
diff --git a/data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py b/data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py
new file mode 100644
index 0000000000000000000000000000000000000000..f01c1681a8e39046645cfdb3e5d79b4b82cf9b46
--- /dev/null
+++ b/data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py
@@ -0,0 +1,87 @@
+import os, glob
+import cv2
+from utils.commons.os_utils import multiprocess_glob
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+
+def get_video_infos(video_path):
+ vid_cap = cv2.VideoCapture(video_path)
+ height = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ width = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
+ total_frames = int(vid_cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ return {'height': height, 'width': width, 'fps': fps, 'total_frames':total_frames}
+
+def extract_img_job(video_name:str):
+ out_path = video_name.replace("/video_raw/","/video/",1)
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ ffmpeg_path = "/usr/bin/ffmpeg"
+ vid_info = get_video_infos(video_name)
+ assert vid_info['width'] == vid_info['height']
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
+ os.system(cmd)
+
+def extract_img_job_crop(video_name:str):
+ out_path = video_name.replace("/video_raw/","/video/",1)
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ ffmpeg_path = "/usr/bin/ffmpeg"
+ vid_info = get_video_infos(video_name)
+ wh = min(vid_info['width'], vid_info['height'])
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop={wh}:{wh},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
+ os.system(cmd)
+
+def extract_img_job_crop_ravdess(video_name:str):
+ out_path = video_name.replace("/video_raw/","/video/",1)
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ ffmpeg_path = "/usr/bin/ffmpeg"
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop=720:720,scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
+ os.system(cmd)
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video_raw/')
+ parser.add_argument("--ds_name", default='CelebV-HQ')
+ parser.add_argument("--num_workers", default=32, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ args = parser.parse_args()
+ print(f"args {args}")
+
+ vid_dir = args.vid_dir
+ ds_name = args.ds_name
+ if ds_name in ['lrs3_trainval']:
+ mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ vid_names = multiprocess_glob(os.path.join(vid_dir, "*.mp4"))
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ vid_names = multiprocess_glob(vid_name_pattern)
+ elif ds_name in ["RAVDESS", 'VFHQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
+ vid_names = multiprocess_glob(vid_name_pattern)
+ else:
+ raise NotImplementedError()
+ vid_names = sorted(vid_names)
+ print(f"total video number : {len(vid_names)}")
+ print(f"first {vid_names[0]} last {vid_names[-1]}")
+ # exit()
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(vid_names) // total_process
+ if process_id == total_process:
+ vid_names = vid_names[process_id * num_samples_per_process : ]
+ else:
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ if ds_name == "RAVDESS":
+ for i, res in multiprocess_run_tqdm(extract_img_job_crop_ravdess, vid_names, num_workers=args.num_workers, desc="resampling videos"):
+ pass
+ elif ds_name == "CMLR":
+ for i, res in multiprocess_run_tqdm(extract_img_job_crop, vid_names, num_workers=args.num_workers, desc="resampling videos"):
+ pass
+ else:
+ for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="resampling videos"):
+ pass
+
diff --git a/data_gen/utils/process_video/split_video_to_imgs.py b/data_gen/utils/process_video/split_video_to_imgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1c16c3415fb953c965cf56b3161a59460375079
--- /dev/null
+++ b/data_gen/utils/process_video/split_video_to_imgs.py
@@ -0,0 +1,53 @@
+import os, glob
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+
+from data_gen.utils.path_converter import PathConverter, pc
+
+# mp4_names = glob.glob("/home/tiger/datasets/raw/CelebV-HQ/video/*.mp4")
+
+def extract_img_job(video_name, raw_img_dir=None):
+ if raw_img_dir is not None:
+ out_path = raw_img_dir
+ else:
+ out_path = pc.to(video_name.replace(".mp4", ""), "vid", "gt")
+ os.makedirs(out_path, exist_ok=True)
+ ffmpeg_path = "/usr/bin/ffmpeg"
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet {os.path.join(out_path, "%8d.jpg")}'
+ os.system(cmd)
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
+ parser.add_argument("--ds_name", default='CelebV-HQ')
+ parser.add_argument("--num_workers", default=64, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ args = parser.parse_args()
+ vid_dir = args.vid_dir
+ ds_name = args.ds_name
+ if ds_name in ['lrs3_trainval']:
+ mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ vid_names = glob.glob(vid_name_pattern)
+ elif ds_name in ["RAVDESS", 'VFHQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
+ vid_names = glob.glob(vid_name_pattern)
+ vid_names = sorted(vid_names)
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(vid_names) // total_process
+ if process_id == total_process:
+ vid_names = vid_names[process_id * num_samples_per_process : ]
+ else:
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="extracting images"):
+ pass
+
diff --git a/data_util/face3d_helper.py b/data_util/face3d_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4d260b1a9c320639ad035bb1d804ed21b076092
--- /dev/null
+++ b/data_util/face3d_helper.py
@@ -0,0 +1,309 @@
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+from scipy.io import loadmat
+
+from deep_3drecon.deep_3drecon_models.bfm import perspective_projection
+
+
+class Face3DHelper(nn.Module):
+ def __init__(self, bfm_dir='deep_3drecon/BFM', keypoint_mode='lm68', use_gpu=True):
+ super().__init__()
+ self.keypoint_mode = keypoint_mode # lm68 | mediapipe
+ self.bfm_dir = bfm_dir
+ self.load_3dmm()
+ if use_gpu: self.to("cuda")
+
+ def load_3dmm(self):
+ model = loadmat(os.path.join(self.bfm_dir, "BFM_model_front.mat"))
+ self.register_buffer('mean_shape',torch.from_numpy(model['meanshape'].transpose()).float()) # mean face shape. [3*N, 1], N=35709, xyz=3, ==> 3*N=107127
+ mean_shape = self.mean_shape.reshape([-1, 3])
+ # re-center
+ mean_shape = mean_shape - torch.mean(mean_shape, dim=0, keepdims=True)
+ self.mean_shape = mean_shape.reshape([-1, 1])
+ self.register_buffer('id_base',torch.from_numpy(model['idBase']).float()) # identity basis. [3*N,80], we have 80 eigen faces for identity
+ self.register_buffer('exp_base',torch.from_numpy(model['exBase']).float()) # expression basis. [3*N,64], we have 64 eigen faces for expression
+
+ self.register_buffer('mean_texure',torch.from_numpy(model['meantex'].transpose()).float()) # mean face texture. [3*N,1] (0-255)
+ self.register_buffer('tex_base',torch.from_numpy(model['texBase']).float()) # texture basis. [3*N,80], rgb=3
+
+ self.register_buffer('point_buf',torch.from_numpy(model['point_buf']).float()) # triangle indices for each vertex that lies in. starts from 1. [N,8] (1-F)
+ self.register_buffer('face_buf',torch.from_numpy(model['tri']).float()) # vertex indices in each triangle. starts from 1. [F,3] (1-N)
+ if self.keypoint_mode == 'mediapipe':
+ self.register_buffer('key_points', torch.from_numpy(np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)))
+ unmatch_mask = self.key_points < 0
+ self.key_points[unmatch_mask] = 0
+ else:
+ self.register_buffer('key_points',torch.from_numpy(model['keypoints'].squeeze().astype(np.int_)).long()) # vertex indices of 68 facial landmarks. starts from 1. [68,1]
+
+
+ self.register_buffer('key_mean_shape',self.mean_shape.reshape([-1,3])[self.key_points,:])
+ self.register_buffer('key_id_base', self.id_base.reshape([-1,3,80])[self.key_points, :, :].reshape([-1,80]))
+ self.register_buffer('key_exp_base', self.exp_base.reshape([-1,3,64])[self.key_points, :, :].reshape([-1,64]))
+ self.key_id_base_np = self.key_id_base.cpu().numpy()
+ self.key_exp_base_np = self.key_exp_base.cpu().numpy()
+
+ self.register_buffer('persc_proj', torch.tensor(perspective_projection(focal=1015, center=112)))
+ def split_coeff(self, coeff):
+ """
+ coeff: Tensor[B, T, c=257] or [T, c=257]
+ """
+ ret_dict = {
+ 'identity': coeff[..., :80], # identity, [b, t, c=80]
+ 'expression': coeff[..., 80:144], # expression, [b, t, c=80]
+ 'texture': coeff[..., 144:224], # texture, [b, t, c=80]
+ 'euler': coeff[..., 224:227], # euler euler for pose, [b, t, c=3]
+ 'translation': coeff[..., 254:257], # translation, [b, t, c=3]
+ 'gamma': coeff[..., 227:254] # lighting, [b, t, c=27]
+ }
+ return ret_dict
+
+ def reconstruct_face_mesh(self, id_coeff, exp_coeff):
+ """
+ Generate a pose-independent 3D face mesh!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ id_coeff = id_coeff.to(self.key_id_base.device)
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
+ mean_face = self.mean_shape.squeeze().reshape([1, -1]) # [3N, 1] ==> [1, 3N]
+ id_base, exp_base = self.id_base, self.exp_base # [3*N, C]
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]
+
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
+ # mean_xyz = self.mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
+ # face_mesh = face - mean_xyz.unsqueeze(0) # [t,N,3]
+ return face
+
+ def reconstruct_cano_lm3d(self, id_coeff, exp_coeff):
+ """
+ Generate 3D landmark with keypoint base!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ id_coeff = id_coeff.to(self.key_id_base.device)
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
+ # mean_xyz = self.key_mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
+ # lm3d = face - mean_xyz.unsqueeze(0) # [t,N,3]
+ return face
+
+ def reconstruct_lm3d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
+ """
+ Generate 3D landmark with keypoint base!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ id_coeff = id_coeff.to(self.key_id_base.device)
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
+ rot = self.compute_rotation(euler)
+ # transform
+ lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
+ # to camera
+ if to_camera:
+ lm3d[...,-1] = 10 - lm3d[...,-1]
+ return lm3d
+
+ def reconstruct_lm2d_nerf(self, id_coeff, exp_coeff, euler, trans):
+ lm2d = self.reconstruct_lm2d(id_coeff, exp_coeff, euler, trans, to_camera=False)
+ lm2d[..., 0] = 1 - lm2d[..., 0]
+ lm2d[..., 1] = 1 - lm2d[..., 1]
+ return lm2d
+
+ def reconstruct_lm2d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
+ """
+ Generate 3D landmark with keypoint base!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ is_btc_flag = True if id_coeff.ndim == 3 else False
+ if is_btc_flag:
+ b,t,_ = id_coeff.shape
+ id_coeff = id_coeff.reshape([b*t,-1])
+ exp_coeff = exp_coeff.reshape([b*t,-1])
+ euler = euler.reshape([b*t,-1])
+ trans = trans.reshape([b*t,-1])
+ id_coeff = id_coeff.to(self.key_id_base.device)
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
+ rot = self.compute_rotation(euler)
+ # transform
+ lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
+ # to camera
+ if to_camera:
+ lm3d[...,-1] = 10 - lm3d[...,-1]
+ # to image_plane
+ lm3d = lm3d @ self.persc_proj
+ lm2d = lm3d[..., :2] / lm3d[..., 2:]
+ # flip
+ lm2d[..., 1] = 224 - lm2d[..., 1]
+ lm2d /= 224
+ if is_btc_flag:
+ return lm2d.reshape([b,t,-1,2])
+ return lm2d
+
+ def compute_rotation(self, euler):
+ """
+ Return:
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
+
+ Parameters:
+ euler -- torch.tensor, size (B, 3), radian
+ """
+
+ batch_size = euler.shape[0]
+ euler = euler.to(self.key_id_base.device)
+ ones = torch.ones([batch_size, 1]).to(self.key_id_base.device)
+ zeros = torch.zeros([batch_size, 1]).to(self.key_id_base.device)
+ x, y, z = euler[:, :1], euler[:, 1:2], euler[:, 2:],
+
+ rot_x = torch.cat([
+ ones, zeros, zeros,
+ zeros, torch.cos(x), -torch.sin(x),
+ zeros, torch.sin(x), torch.cos(x)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_y = torch.cat([
+ torch.cos(y), zeros, torch.sin(y),
+ zeros, ones, zeros,
+ -torch.sin(y), zeros, torch.cos(y)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_z = torch.cat([
+ torch.cos(z), -torch.sin(z), zeros,
+ torch.sin(z), torch.cos(z), zeros,
+ zeros, zeros, ones
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot = rot_z @ rot_y @ rot_x
+ return rot.permute(0, 2, 1)
+
+ def reconstruct_idexp_lm3d(self, id_coeff, exp_coeff):
+ """
+ Generate 3D landmark with keypoint base!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ id_coeff = id_coeff.to(self.key_id_base.device)
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+
+ face = identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ lm3d = face * 10
+ return lm3d
+
+ def reconstruct_idexp_lm3d_np(self, id_coeff, exp_coeff):
+ """
+ Generate 3D landmark with keypoint base!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ id_base, exp_base = self.key_id_base_np, self.key_exp_base_np # [3*68, C]
+ identity_diff_face = np.dot(id_coeff, id_base.T) # [t,c],[c,3*68] ==> [t,3*68]
+ expression_diff_face = np.dot(exp_coeff, exp_base.T) # [t,c],[c,3*68] ==> [t,3*68]
+
+ face = identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ lm3d = face * 10
+ return lm3d
+
+ def get_eye_mouth_lm_from_lm3d(self, lm3d):
+ eye_lm = lm3d[:, 17:48] # [T, 31, 3]
+ mouth_lm = lm3d[:, 48:68] # [T, 20, 3]
+ return eye_lm, mouth_lm
+
+ def get_eye_mouth_lm_from_lm3d_batch(self, lm3d):
+ eye_lm = lm3d[:, :, 17:48] # [T, 31, 3]
+ mouth_lm = lm3d[:, :, 48:68] # [T, 20, 3]
+ return eye_lm, mouth_lm
+
+ def close_mouth_for_idexp_lm3d(self, idexp_lm3d, freeze_as_first_frame=True):
+ idexp_lm3d = idexp_lm3d.reshape([-1, 68,3])
+ num_frames = idexp_lm3d.shape[0]
+ eps = 0.0
+ # [n_landmarks=68,xyz=3], x 代表左右,y代表上下,z代表深度
+ idexp_lm3d[:,49:54, 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 + eps * 2
+ idexp_lm3d[:,range(59,54,-1), 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 - eps * 2
+
+ idexp_lm3d[:,61:64, 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 + eps
+ idexp_lm3d[:,range(67,64,-1), 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 - eps
+
+ idexp_lm3d[:,49:54, 1] += (0.03 - idexp_lm3d[:,49:54, 1].mean(dim=1) + idexp_lm3d[:,61:64, 1].mean(dim=1)).unsqueeze(1).repeat([1,5])
+ idexp_lm3d[:,range(59,54,-1), 1] += (-0.03 - idexp_lm3d[:,range(59,54,-1), 1].mean(dim=1) + idexp_lm3d[:,range(67,64,-1), 1].mean(dim=1)).unsqueeze(1).repeat([1,5])
+
+ if freeze_as_first_frame:
+ idexp_lm3d[:, 48:68,] = idexp_lm3d[0, 48:68].unsqueeze(0).clone().repeat([num_frames, 1,1])*0
+ return idexp_lm3d.cpu()
+
+ def close_eyes_for_idexp_lm3d(self, idexp_lm3d):
+ idexp_lm3d = idexp_lm3d.reshape([-1, 68,3])
+ eps = 0.003
+ idexp_lm3d[:,37:39, 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 + eps
+ idexp_lm3d[:,range(41,39,-1), 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 - eps
+
+ idexp_lm3d[:,43:45, 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 + eps
+ idexp_lm3d[:,range(47,45,-1), 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 - eps
+
+ return idexp_lm3d
+
+if __name__ == '__main__':
+ import cv2
+
+ font = cv2.FONT_HERSHEY_SIMPLEX
+
+ face_mesh_helper = Face3DHelper('deep_3drecon/BFM')
+ coeff_npy = 'data/coeff_fit_mp/crop_nana_003_coeff_fit_mp.npy'
+ coeff_dict = np.load(coeff_npy, allow_pickle=True).tolist()
+ lm3d = face_mesh_helper.reconstruct_lm2d(torch.tensor(coeff_dict['id']).cuda(), torch.tensor(coeff_dict['exp']).cuda(), torch.tensor(coeff_dict['euler']).cuda(), torch.tensor(coeff_dict['trans']).cuda() )
+
+ WH = 512
+ lm3d = (lm3d * WH).cpu().int().numpy()
+ eye_idx = list(range(36,48))
+ mouth_idx = list(range(48,68))
+ import imageio
+ debug_name = 'debug_lm3d.mp4'
+ writer = imageio.get_writer(debug_name, fps=25)
+ for i_img in range(len(lm3d)):
+ lm2d = lm3d[i_img ,:, :2] # [68, 2]
+ img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
+ for i in range(len(lm2d)):
+ x, y = lm2d[i]
+ if i in eye_idx:
+ color = (0,0,255)
+ elif i in mouth_idx:
+ color = (0,255,0)
+ else:
+ color = (255,0,0)
+ img = cv2.circle(img, center=(x,y), radius=3, color=color, thickness=-1)
+ img = cv2.putText(img, f"{i}", org=(x,y), fontFace=font, fontScale=0.3, color=(255,0,0))
+ writer.append_data(img)
+ writer.close()
diff --git a/deep_3drecon/BFM/.gitkeep b/deep_3drecon/BFM/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/deep_3drecon/bfm_left_eye_faces.npy b/deep_3drecon/bfm_left_eye_faces.npy
new file mode 100644
index 0000000000000000000000000000000000000000..7044bb788d7f382888649a1b138912be259bbd78
--- /dev/null
+++ b/deep_3drecon/bfm_left_eye_faces.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9651756ea2c0fac069a1edf858ed1f125eddc358fa74c529a370c1e7b5730d28
+size 4680
diff --git a/deep_3drecon/bfm_right_eye_faces.npy b/deep_3drecon/bfm_right_eye_faces.npy
new file mode 100644
index 0000000000000000000000000000000000000000..b995860e0c2021a548c413e5add0976f4dc34db7
--- /dev/null
+++ b/deep_3drecon/bfm_right_eye_faces.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28cb5bbacf578d30a3d5006ec28c617fe5a3ecaeeeb87d9433a884e0f0301a2e
+size 4648
diff --git a/deep_3drecon/deep_3drecon_models/bfm.py b/deep_3drecon/deep_3drecon_models/bfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce2cb08ba673a7d7e7c5db11dc2b394aca879ebb
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/bfm.py
@@ -0,0 +1,426 @@
+"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.io import loadmat
+import os
+# from utils.commons.tensor_utils import convert_like
+
+
+def perspective_projection(focal, center):
+ # return p.T (N, 3) @ (3, 3)
+ return np.array([
+ focal, 0, center,
+ 0, focal, center,
+ 0, 0, 1
+ ]).reshape([3, 3]).astype(np.float32).transpose() # 注意这里的transpose!
+
+class SH:
+ def __init__(self):
+ self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
+ self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
+
+
+
+class ParametricFaceModel:
+ def __init__(self,
+ bfm_folder='./BFM',
+ recenter=True,
+ camera_distance=10.,
+ init_lit=np.array([
+ 0.8, 0, 0, 0, 0, 0, 0, 0, 0
+ ]),
+ focal=1015.,
+ center=112.,
+ is_train=True,
+ default_name='BFM_model_front.mat',
+ keypoint_mode='mediapipe'):
+
+ model = loadmat(os.path.join(bfm_folder, default_name))
+ # mean face shape. [3*N,1]
+ self.mean_shape = model['meanshape'].astype(np.float32)
+ # identity basis. [3*N,80]
+ self.id_base = model['idBase'].astype(np.float32)
+ # expression basis. [3*N,64]
+ self.exp_base = model['exBase'].astype(np.float32)
+ # mean face texture. [3*N,1] (0-255)
+ self.mean_tex = model['meantex'].astype(np.float32)
+ # texture basis. [3*N,80]
+ self.tex_base = model['texBase'].astype(np.float32)
+ # face indices for each vertex that lies in. starts from 0. [N,8]
+ self.point_buf = model['point_buf'].astype(np.int64) - 1
+ # vertex indices for each face. starts from 0. [F,3]
+ self.face_buf = model['tri'].astype(np.int64) - 1
+ # vertex indices for 68 landmarks. starts from 0. [68,1]
+ if keypoint_mode == 'mediapipe':
+ self.keypoints = np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)
+ unmatch_mask = self.keypoints < 0
+ self.keypoints[unmatch_mask] = 0
+ else:
+ self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
+
+ if is_train:
+ # vertex indices for small face region to compute photometric error. starts from 0.
+ self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
+ # vertex indices for each face from small face region. starts from 0. [f,3]
+ self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
+ # vertex indices for pre-defined skin region to compute reflectance loss
+ self.skin_mask = np.squeeze(model['skinmask'])
+
+ if recenter:
+ mean_shape = self.mean_shape.reshape([-1, 3])
+ mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
+ self.mean_shape = mean_shape.reshape([-1, 1])
+
+ self.key_mean_shape = self.mean_shape.reshape([-1, 3])[self.keypoints, :].reshape([-1, 3])
+ self.key_id_base = self.id_base.reshape([-1, 3,80])[self.keypoints, :].reshape([-1, 80])
+ self.key_exp_base = self.exp_base.reshape([-1, 3, 64])[self.keypoints, :].reshape([-1, 64])
+
+ self.focal = focal
+ self.center = center
+ self.persc_proj = perspective_projection(focal, center)
+ self.device = 'cpu'
+ self.camera_distance = camera_distance
+ self.SH = SH()
+ self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
+
+ self.initialized = False
+
+ def to(self, device):
+ self.device = device
+ for key, value in self.__dict__.items():
+ if type(value).__module__ == np.__name__:
+ setattr(self, key, torch.tensor(value).to(device))
+ self.initialized = True
+ return self
+
+ def compute_shape(self, id_coeff, exp_coeff):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ id_coeff -- torch.tensor, size (B, 80), identity coeffs
+ exp_coeff -- torch.tensor, size (B, 64), expression coeffs
+ """
+ batch_size = id_coeff.shape[0]
+ id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
+ exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
+ face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
+ return face_shape.reshape([batch_size, -1, 3])
+
+ def compute_key_shape(self, id_coeff, exp_coeff):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ id_coeff -- torch.tensor, size (B, 80), identity coeffs
+ exp_coeff -- torch.tensor, size (B, 64), expression coeffs
+ """
+ batch_size = id_coeff.shape[0]
+ id_part = torch.einsum('ij,aj->ai', self.key_id_base, id_coeff)
+ exp_part = torch.einsum('ij,aj->ai', self.key_exp_base, exp_coeff)
+ face_shape = id_part + exp_part + self.key_mean_shape.reshape([1, -1])
+ return face_shape.reshape([batch_size, -1, 3])
+
+ def compute_texture(self, tex_coeff, normalize=True):
+ """
+ Return:
+ face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
+
+ Parameters:
+ tex_coeff -- torch.tensor, size (B, 80)
+ """
+ batch_size = tex_coeff.shape[0]
+ face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex
+ if normalize:
+ face_texture = face_texture / 255.
+ return face_texture.reshape([batch_size, -1, 3])
+
+
+ def compute_norm(self, face_shape):
+ """
+ Return:
+ vertex_norm -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+
+ v1 = face_shape[:, self.face_buf[:, 0]]
+ v2 = face_shape[:, self.face_buf[:, 1]]
+ v3 = face_shape[:, self.face_buf[:, 2]]
+ e1 = v1 - v2
+ e2 = v2 - v3
+ face_norm = torch.cross(e1, e2, dim=-1)
+ face_norm = F.normalize(face_norm, dim=-1, p=2)
+ face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
+
+ vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
+ vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
+ return vertex_norm
+
+
+ def compute_color(self, face_texture, face_norm, gamma):
+ """
+ Return:
+ face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
+
+ Parameters:
+ face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
+ face_norm -- torch.tensor, size (B, N, 3), rotated face normal
+ gamma -- torch.tensor, size (B, 27), SH coeffs
+ """
+ batch_size = gamma.shape[0]
+ v_num = face_texture.shape[1]
+ a, c = self.SH.a, self.SH.c
+ gamma = gamma.reshape([batch_size, 3, 9])
+ gamma = gamma + self.init_lit
+ gamma = gamma.permute(0, 2, 1)
+ Y = torch.cat([
+ a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
+ -a[1] * c[1] * face_norm[..., 1:2],
+ a[1] * c[1] * face_norm[..., 2:],
+ -a[1] * c[1] * face_norm[..., :1],
+ a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
+ -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
+ -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
+ ], dim=-1)
+ r = Y @ gamma[..., :1]
+ g = Y @ gamma[..., 1:2]
+ b = Y @ gamma[..., 2:]
+ face_color = torch.cat([r, g, b], dim=-1) * face_texture
+ return face_color
+
+ @staticmethod
+ def compute_rotation(angles, device='cpu'):
+ """
+ Return:
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
+
+ Parameters:
+ angles -- torch.tensor, size (B, 3), radian
+ """
+
+ batch_size = angles.shape[0]
+ angles = angles.to(device)
+ ones = torch.ones([batch_size, 1]).to(device)
+ zeros = torch.zeros([batch_size, 1]).to(device)
+ x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
+
+ rot_x = torch.cat([
+ ones, zeros, zeros,
+ zeros, torch.cos(x), -torch.sin(x),
+ zeros, torch.sin(x), torch.cos(x)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_y = torch.cat([
+ torch.cos(y), zeros, torch.sin(y),
+ zeros, ones, zeros,
+ -torch.sin(y), zeros, torch.cos(y)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_z = torch.cat([
+ torch.cos(z), -torch.sin(z), zeros,
+ torch.sin(z), torch.cos(z), zeros,
+ zeros, zeros, ones
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot = rot_z @ rot_y @ rot_x
+ return rot.permute(0, 2, 1)
+
+
+ def to_camera(self, face_shape):
+ face_shape[..., -1] = self.camera_distance - face_shape[..., -1] # reverse the depth axis, add a fixed offset of length
+ return face_shape
+
+ def to_image(self, face_shape):
+ """
+ Return:
+ face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+ # to image_plane
+ face_proj = face_shape @ self.persc_proj
+ face_proj = face_proj[..., :2] / face_proj[..., 2:]
+
+ return face_proj
+
+
+ def transform(self, face_shape, rot, trans):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
+
+ Parameters:
+ face_shape -- torch.tensor, si≥ze (B, N, 3)
+ rot -- torch.tensor, size (B, 3, 3)
+ trans -- torch.tensor, size (B, 3)
+ """
+ return face_shape @ rot + trans.unsqueeze(1)
+
+
+ def get_landmarks(self, face_proj):
+ """
+ Return:
+ face_lms -- torch.tensor, size (B, 68, 2)
+
+ Parameters:
+ face_proj -- torch.tensor, size (B, N, 2)
+ """
+ return face_proj[:, self.keypoints]
+
+ def split_coeff(self, coeffs):
+ """
+ Return:
+ coeffs_dict -- a dict of torch.tensors
+
+ Parameters:
+ coeffs -- torch.tensor, size (B, 256)
+ """
+ id_coeffs = coeffs[:, :80]
+ exp_coeffs = coeffs[:, 80: 144]
+ tex_coeffs = coeffs[:, 144: 224]
+ angles = coeffs[:, 224: 227]
+ gammas = coeffs[:, 227: 254]
+ translations = coeffs[:, 254:]
+ return {
+ 'id': id_coeffs,
+ 'exp': exp_coeffs,
+ 'tex': tex_coeffs,
+ 'angle': angles,
+ 'gamma': gammas,
+ 'trans': translations
+ }
+ def compute_for_render(self, coeffs):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ coef_dict = self.split_coeff(coeffs)
+ face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+ rotation = self.compute_rotation(coef_dict['angle'], device=self.device)
+
+
+ face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
+ face_vertex = self.to_camera(face_shape_transformed)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = self.get_landmarks(face_proj)
+
+ face_texture = self.compute_texture(coef_dict['tex'])
+ face_norm = self.compute_norm(face_shape)
+ face_norm_roted = face_norm @ rotation
+ face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+ return face_vertex, face_texture, face_color, landmark
+
+ def compute_face_vertex(self, id, exp, angle, trans):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ if not self.initialized:
+ self.to(id.device)
+ face_shape = self.compute_shape(id, exp)
+ rotation = self.compute_rotation(angle, device=self.device)
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
+ face_vertex = self.to_camera(face_shape_transformed)
+ return face_vertex
+
+ def compute_for_landmark_fit(self, id, exp, angles, trans, ret=None):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ face_shape = self.compute_key_shape(id, exp)
+ rotation = self.compute_rotation(angles, device=self.device)
+
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
+ face_vertex = self.to_camera(face_shape_transformed)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = face_proj
+ return landmark
+
+ def compute_for_landmark_fit_nerf(self, id, exp, angles, trans, ret=None):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ face_shape = self.compute_key_shape(id, exp)
+ rotation = self.compute_rotation(angles, device=self.device)
+
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
+ face_vertex = face_shape_transformed # no to_camera
+
+ face_proj = self.to_image(face_vertex)
+ landmark = face_proj
+ return landmark
+
+ # def compute_for_landmark_fit(self, id, exp, angles, trans, ret={}):
+ # """
+ # Return:
+ # face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ # face_color -- torch.tensor, size (B, N, 3), in RGB order
+ # landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ # Parameters:
+ # coeffs -- torch.tensor, size (B, 257)
+ # """
+ # face_shape = self.compute_shape(id, exp)
+ # rotation = self.compute_rotation(angles)
+
+ # face_shape_transformed = self.transform(face_shape, rotation, trans)
+ # face_vertex = self.to_camera(face_shape_transformed)
+
+ # face_proj = self.to_image(face_vertex)
+ # landmark = self.get_landmarks(face_proj)
+ # return landmark
+
+ def compute_for_render_fit(self, id, exp, angles, trans, tex, gamma):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ face_shape = self.compute_shape(id, exp)
+ rotation = self.compute_rotation(angles, device=self.device)
+
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
+ face_vertex = self.to_camera(face_shape_transformed)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = self.get_landmarks(face_proj)
+
+ face_texture = self.compute_texture(tex)
+ face_norm = self.compute_norm(face_shape)
+ face_norm_roted = face_norm @ rotation
+ face_color = self.compute_color(face_texture, face_norm_roted, gamma)
+
+ return face_color, face_vertex, landmark
\ No newline at end of file
diff --git a/deep_3drecon/ncc_code.npy b/deep_3drecon/ncc_code.npy
new file mode 100644
index 0000000000000000000000000000000000000000..79568a9ce3c7a903cea7ec76f1870f15fd052f13
--- /dev/null
+++ b/deep_3drecon/ncc_code.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da54a620c0981d43cc9f30b3d8b3f5d4beb0ec0e27127a1ef3fb62ea50913609
+size 428636
diff --git a/deep_3drecon/secc_renderer.py b/deep_3drecon/secc_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d6b3cdc4051c1ad9ed98b3228a35f35d573ab7c
--- /dev/null
+++ b/deep_3drecon/secc_renderer.py
@@ -0,0 +1,78 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+from deep_3drecon.util.mesh_renderer import MeshRenderer
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+
+
+class SECC_Renderer(nn.Module):
+ def __init__(self, rasterize_size=None, device="cuda"):
+ super().__init__()
+ self.face_model = ParametricFaceModel('deep_3drecon/BFM')
+ self.fov = 2 * np.arctan(self.face_model.center / self.face_model.focal) * 180 / np.pi
+ self.znear = 5.
+ self.zfar = 15.
+ if rasterize_size is None:
+ rasterize_size = 2*self.face_model.center
+ self.face_renderer = MeshRenderer(rasterize_fov=self.fov, znear=self.znear, zfar=self.zfar, rasterize_size=rasterize_size, use_opengl=False).cuda()
+ face_feat = np.load("deep_3drecon/ncc_code.npy", allow_pickle=True)
+ self.face_feat = torch.tensor(face_feat.T).unsqueeze(0).to(device=device)
+
+ del_index_re = np.load('deep_3drecon/bfm_right_eye_faces.npy')
+ del_index_re = del_index_re - 1
+ del_index_le = np.load('deep_3drecon/bfm_left_eye_faces.npy')
+ del_index_le = del_index_le - 1
+ face_buf_list = []
+ for i in range(self.face_model.face_buf.shape[0]):
+ if i not in del_index_re and i not in del_index_le:
+ face_buf_list.append(self.face_model.face_buf[i])
+ face_buf_arr = np.array(face_buf_list)
+ self.face_buf = torch.tensor(face_buf_arr).to(device=device)
+
+ def forward(self, id, exp, euler, trans):
+ """
+ id, exp, euler, euler: [B, C] or [B, T, C]
+ return:
+ MASK: [B, 1, 512, 512], value[0. or 1.0], 1.0 denotes is face
+ SECC MAP: [B, 3, 512, 512], value[0~1]
+ if input is BTC format, return [B, C, T, H, W]
+ """
+ bs = id.shape[0]
+ is_btc_flag = id.ndim == 3
+ if is_btc_flag:
+ t = id.shape[1]
+ bs = bs * t
+ id, exp, euler, trans = id.reshape([bs,-1]), exp.reshape([bs,-1]), euler.reshape([bs,-1]), trans.reshape([bs,-1])
+
+ face_vertex = self.face_model.compute_face_vertex(id, exp, euler, trans)
+ face_mask, _, secc_face = self.face_renderer(
+ face_vertex, self.face_buf.unsqueeze(0).repeat([bs, 1, 1]), feat=self.face_feat.repeat([bs,1,1]))
+ secc_face = (secc_face - 0.5) / 0.5 # scale to -1~1
+
+ if is_btc_flag:
+ bs = bs // t
+ face_mask = rearrange(face_mask, "(n t) c h w -> n c t h w", n=bs, t=t)
+ secc_face = rearrange(secc_face, "(n t) c h w -> n c t h w", n=bs, t=t)
+ return face_mask, secc_face
+
+
+if __name__ == '__main__':
+ import imageio
+
+ renderer = SECC_Renderer(rasterize_size=512)
+ ret = np.load("data/processed/videos/May/vid_coeff_fit.npy", allow_pickle=True).tolist()
+ idx = 6
+ id = torch.tensor(ret['id']).cuda()[idx:idx+1]
+ exp = torch.tensor(ret['exp']).cuda()[idx:idx+1]
+ angle = torch.tensor(ret['euler']).cuda()[idx:idx+1]
+ trans = torch.tensor(ret['trans']).cuda()[idx:idx+1]
+ mask, secc = renderer(id, exp, angle*0, trans*0) # [1, 1, 512, 512], [1, 3, 512, 512]
+
+ out_mask = mask[0].permute(1,2,0)
+ out_mask = (out_mask * 127.5 + 127.5).int().cpu().numpy()
+ imageio.imwrite("out_mask.png", out_mask)
+ out_img = secc[0].permute(1,2,0)
+ out_img = (out_img * 127.5 + 127.5).int().cpu().numpy()
+ imageio.imwrite("out_secc.png", out_img)
\ No newline at end of file
diff --git a/deep_3drecon/util/mesh_renderer.py b/deep_3drecon/util/mesh_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b6e765d706fb31cbe7f0b4403b492893ca32221
--- /dev/null
+++ b/deep_3drecon/util/mesh_renderer.py
@@ -0,0 +1,131 @@
+"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
+ Attention, antialiasing step is missing in current version.
+"""
+import torch
+import torch.nn.functional as F
+import kornia
+from kornia.geometry.camera import pixel2cam
+import numpy as np
+from typing import List
+from scipy.io import loadmat
+from torch import nn
+import traceback
+
+try:
+ import pytorch3d.ops
+ from pytorch3d.structures import Meshes
+ from pytorch3d.renderer import (
+ look_at_view_transform,
+ FoVPerspectiveCameras,
+ DirectionalLights,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRasterizer,
+ SoftPhongShader,
+ TexturesUV,
+ )
+except:
+ traceback.print_exc()
+# def ndc_projection(x=0.1, n=1.0, f=50.0):
+# return np.array([[n/x, 0, 0, 0],
+# [ 0, n/-x, 0, 0],
+# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
+# [ 0, 0, -1, 0]]).astype(np.float32)
+
+class MeshRenderer(nn.Module):
+ def __init__(self,
+ rasterize_fov,
+ znear=0.1,
+ zfar=10,
+ rasterize_size=224,**args):
+ super(MeshRenderer, self).__init__()
+
+ # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
+ # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
+ # torch.diag(torch.tensor([1., -1, -1, 1])))
+ self.rasterize_size = rasterize_size
+ self.fov = rasterize_fov
+ self.znear = znear
+ self.zfar = zfar
+
+ self.rasterizer = None
+
+ def forward(self, vertex, tri, feat=None):
+ """
+ Return:
+ mask -- torch.tensor, size (B, 1, H, W)
+ depth -- torch.tensor, size (B, 1, H, W)
+ features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
+
+ Parameters:
+ vertex -- torch.tensor, size (B, N, 3)
+ tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
+ feat(optional) -- torch.tensor, size (B, N ,C), features
+ """
+ device = vertex.device
+ rsize = int(self.rasterize_size)
+ # ndc_proj = self.ndc_proj.to(device)
+ # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v
+ if vertex.shape[-1] == 3:
+ vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
+ vertex[..., 0] = -vertex[..., 0]
+
+
+ # vertex_ndc = vertex @ ndc_proj.t()
+ if self.rasterizer is None:
+ self.rasterizer = MeshRasterizer()
+ print("create rasterizer on device cuda:%d"%device.index)
+
+ # ranges = None
+ # if isinstance(tri, List) or len(tri.shape) == 3:
+ # vum = vertex_ndc.shape[1]
+ # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
+ # fstartidx = torch.cumsum(fnum, dim=0) - fnum
+ # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
+ # for i in range(tri.shape[0]):
+ # tri[i] = tri[i] + i*vum
+ # vertex_ndc = torch.cat(vertex_ndc, dim=0)
+ # tri = torch.cat(tri, dim=0)
+
+ # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3]
+ tri = tri.type(torch.int32).contiguous()
+
+ # rasterize
+ cameras = FoVPerspectiveCameras(
+ device=device,
+ fov=self.fov,
+ znear=self.znear,
+ zfar=self.zfar,
+ )
+
+ raster_settings = RasterizationSettings(
+ image_size=rsize
+ )
+
+ # print(vertex.shape, tri.shape)
+ if tri.ndim == 2:
+ tri = tri.unsqueeze(0)
+ mesh = Meshes(vertex.contiguous()[...,:3], tri)
+
+ fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings)
+ rast_out = fragments.pix_to_face.squeeze(-1)
+ depth = fragments.zbuf
+
+ # render depth
+ depth = depth.permute(0, 3, 1, 2)
+ mask = (rast_out > 0).float().unsqueeze(1)
+ depth = mask * depth
+
+
+ image = None
+ if feat is not None:
+ attributes = feat.reshape(-1,3)[mesh.faces_packed()]
+ image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,
+ fragments.bary_coords,
+ attributes)
+ # print(image.shape)
+ image = image.squeeze(-2).permute(0, 3, 1, 2)
+ image = mask * image
+
+ return mask, depth, image
+
diff --git a/docs/prepare_env/install_guide-zh.md b/docs/prepare_env/install_guide-zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..26d23b3178c471f155c0fb84fc11c6dff78ab884
--- /dev/null
+++ b/docs/prepare_env/install_guide-zh.md
@@ -0,0 +1,35 @@
+# 环境配置
+[English Doc](./install_guide.md)
+
+本文档陈述了搭建Real3D-Portrait Python环境的步骤,我们使用了Conda来管理依赖。
+
+以下配置已在 A100/V100 + CUDA11.7 中进行了验证。
+
+
+# 1. 安装CUDA
+我们推荐安装CUDA `11.7`,其他CUDA版本(例如`10.2`、`12.x`)也可能有效。
+
+# 2. 安装Python依赖
+```
+cd
+source /bin/activate
+conda create -n real3dportrait python=3.9
+conda activate real3dportrait
+conda install conda-forge::ffmpeg # ffmpeg with libx264 codec to turn images to video
+
+# 我们推荐安装torch2.0.1+cuda11.7.
+conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+
+# 从源代码安装,需要比较长的时间 (如果遇到各种time-out问题,建议使用代理)
+pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
+
+# MMCV安装
+pip install cython
+pip install openmim==0.3.9
+mim install mmcv==2.1.0 # 使用mim来加速mmcv安装
+
+# 其他依赖项
+pip install -r docs/prepare_env/requirements.txt -v
+
+```
+
diff --git a/docs/prepare_env/install_guide.md b/docs/prepare_env/install_guide.md
new file mode 100644
index 0000000000000000000000000000000000000000..67f2df44022671e2710e1213bf3133dec89ca382
--- /dev/null
+++ b/docs/prepare_env/install_guide.md
@@ -0,0 +1,34 @@
+# Prepare the Environment
+[中文文档](./install_guide-zh.md)
+
+This guide is about building a python environment for Real3D-Portrait with Conda.
+
+The following installation process is verified in A100/V100 + CUDA11.7.
+
+
+# 1. Install CUDA
+ We recommend to install CUDA `11.7` (which is verified in various types of GPUs), but other CUDA versions (such as `10.2`, `12.x`) may also work well.
+
+# 2. Install Python Packages
+```
+cd
+source /bin/activate
+conda create -n real3dportrait python=3.9
+conda activate real3dportrait
+conda install conda-forge::ffmpeg # ffmpeg with libx264 codec to turn images to video
+
+### We recommend torch2.0.1+cuda11.7.
+conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+
+# Build from source, it may take a long time (Proxy is recommended if encountering the time-out problem)
+pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
+
+# MMCV for some network structure
+pip install cython
+pip install openmim==0.3.9
+mim install mmcv==2.1.0 # use mim to speed up installation for mmcv
+
+# other dependencies
+pip install -r docs/prepare_env/requirements.txt -v
+
+```
diff --git a/docs/prepare_env/requirements.txt b/docs/prepare_env/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d521e594fff966039a7b48de47367fbc11d2c69e
--- /dev/null
+++ b/docs/prepare_env/requirements.txt
@@ -0,0 +1,75 @@
+Cython
+numpy # ==1.23.0
+numba==0.56.4
+pandas
+transformers
+scipy==1.11.1 # required by cal_fid. https://github.com/mseitzer/pytorch-fid/issues/103
+scikit-learn
+scikit-image
+# tensorflow # you can flexible it, this is gpu version
+tensorboard
+tensorboardX
+python_speech_features
+resampy
+opencv_python
+face_alignment
+matplotlib
+configargparse
+librosa==0.9.2
+praat-parselmouth # ==0.4.3
+trimesh
+kornia==0.5.0
+PyMCubes
+lpips
+setuptools # ==59.5.0
+ffmpeg-python
+moviepy
+dearpygui
+ninja
+# pyaudio # for extract esperanto
+mediapipe
+protobuf
+decord
+soundfile
+pillow
+# torch # it's better to install torch with conda
+av
+timm
+pretrainedmodels
+faiss-cpu # for fast nearest camera pose retriveal
+einops
+# mmcv # use mim install is faster
+
+# conditional flow matching
+beartype
+torchode
+torchdiffeq
+
+# tts
+cython
+textgrid
+pyloudnorm
+websocket-client
+pyworld==0.2.1rc0
+pypinyin==0.42.0
+webrtcvad
+torchshow
+
+# cal spk sim
+s3prl
+fire
+
+# cal LMD
+dlib
+
+# debug
+ipykernel
+
+# lama
+hydra-core
+pytorch_lightning
+setproctitle
+
+# Gradio GUI
+httpx==0.23.3
+gradio==4.16.0
\ No newline at end of file
diff --git a/inference/app_real3dportrait.py b/inference/app_real3dportrait.py
new file mode 100644
index 0000000000000000000000000000000000000000..b87dd7c8b50fcb287fa68a725a6dc293793bcf62
--- /dev/null
+++ b/inference/app_real3dportrait.py
@@ -0,0 +1,244 @@
+import os, sys
+import argparse
+import gradio as gr
+from inference.real3d_infer import GeneFace2Infer
+from utils.commons.hparams import hparams
+
+class Inferer(GeneFace2Infer):
+ def infer_once_args(self, *args, **kargs):
+ assert len(kargs) == 0
+ keys = [
+ 'src_image_name',
+ 'drv_audio_name',
+ 'drv_pose_name',
+ 'bg_image_name',
+ 'blink_mode',
+ 'temperature',
+ 'mouth_amp',
+ 'out_mode',
+ 'map_to_init_pose',
+ 'hold_eye_opened',
+ 'head_torso_threshold',
+ 'a2m_ckpt',
+ 'head_ckpt',
+ 'torso_ckpt',
+ ]
+ inp = {}
+ out_name = None
+ info = ""
+
+ try: # try to catch errors and jump to return
+ for key_index in range(len(keys)):
+ key = keys[key_index]
+ inp[key] = args[key_index]
+ if '_name' in key:
+ inp[key] = inp[key] if inp[key] is not None else ''
+
+ if inp['src_image_name'] == '':
+ info = "Input Error: Source image is REQUIRED!"
+ raise ValueError
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] == '':
+ info = "Input Error: At least one of driving audio or video is REQUIRED!"
+ raise ValueError
+
+
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] != '':
+ inp['drv_audio_name'] = inp['drv_pose_name']
+ print("No audio input, we use driving pose video for video driving")
+
+ if inp['drv_pose_name'] == '':
+ inp['drv_pose_name'] = 'static'
+
+ reload_flag = False
+ if inp['a2m_ckpt'] != self.audio2secc_dir:
+ print("Changes of a2m_ckpt detected, reloading model")
+ reload_flag = True
+ if inp['head_ckpt'] != self.head_model_dir:
+ print("Changes of head_ckpt detected, reloading model")
+ reload_flag = True
+ if inp['torso_ckpt'] != self.torso_model_dir:
+ print("Changes of torso_ckpt detected, reloading model")
+ reload_flag = True
+
+ inp['out_name'] = ''
+ inp['seed'] = 42
+
+ print(f"infer inputs : {inp}")
+ if self.secc2video_hparams['htbsr_head_threshold'] != inp['head_torso_threshold']:
+ print("Changes of head_torso_threshold detected, reloading model")
+ reload_flag = True
+
+ try:
+ if reload_flag:
+ self.__init__(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp, device=self.device)
+ except Exception as e:
+ content = f"{e}"
+ info = f"Reload ERROR: {content}"
+ raise ValueError
+ try:
+ out_name = self.infer_once(inp)
+ except Exception as e:
+ content = f"{e}"
+ info = f"Inference ERROR: {content}"
+ raise ValueError
+ except Exception as e:
+ if info == "": # unexpected errors
+ content = f"{e}"
+ info = f"WebUI ERROR: {content}"
+
+ # output part
+ if len(info) > 0 : # there is errors
+ print(info)
+ info_gr = gr.update(visible=True, value=info)
+ else: # no errors
+ info_gr = gr.update(visible=False, value=info)
+ if out_name is not None and len(out_name) > 0 and os.path.exists(out_name): # good output
+ print(f"Succefully generated in {out_name}")
+ video_gr = gr.update(visible=True, value=out_name)
+ else:
+ print(f"Failed to generate")
+ video_gr = gr.update(visible=True, value=out_name)
+
+ return video_gr, info_gr
+
+def toggle_audio_file(choice):
+ if choice == False:
+ return gr.update(visible=True), gr.update(visible=False)
+ else:
+ return gr.update(visible=False), gr.update(visible=True)
+
+def ref_video_fn(path_of_ref_video):
+ if path_of_ref_video is not None:
+ return gr.update(value=True)
+ else:
+ return gr.update(value=False)
+
+def real3dportrait_demo(
+ audio2secc_dir,
+ head_model_dir,
+ torso_model_dir,
+ device = 'cuda',
+ warpfn = None,
+ ):
+
+ sep_line = "-" * 40
+
+ infer_obj = Inferer(
+ audio2secc_dir=audio2secc_dir,
+ head_model_dir=head_model_dir,
+ torso_model_dir=torso_model_dir,
+ device=device,
+ )
+
+ print(sep_line)
+ print("Model loading is finished.")
+ print(sep_line)
+ with gr.Blocks(analytics_enabled=False) as real3dportrait_interface:
+ gr.Markdown("\
+ Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis (ICLR 2024 Spotlight)
\
+
Arxiv \
+
Homepage \
+
Github ")
+
+ sources = None
+ with gr.Row():
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="source_image"):
+ with gr.TabItem('Upload image'):
+ with gr.Row():
+ src_image_name = gr.Image(label="Source image (required)", sources=sources, type="filepath", value="data/raw/examples/Macron.png")
+ with gr.Tabs(elem_id="driven_audio"):
+ with gr.TabItem('Upload audio'):
+ with gr.Column(variant='panel'):
+ drv_audio_name = gr.Audio(label="Input audio (required for audio-driven)", sources=sources, type="filepath", value="data/raw/examples/Obama_5s.wav")
+ with gr.Tabs(elem_id="driven_pose"):
+ with gr.TabItem('Upload video'):
+ with gr.Column(variant='panel'):
+ drv_pose_name = gr.Video(label="Driven Pose (required for video-driven, optional for audio-driven)", sources=sources, value="data/raw/examples/May_5s.mp4")
+ with gr.Tabs(elem_id="bg_image"):
+ with gr.TabItem('Upload image'):
+ with gr.Row():
+ bg_image_name = gr.Image(label="Background image (optional)", sources=sources, type="filepath", value="data/raw/examples/bg.png")
+
+
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="checkbox"):
+ with gr.TabItem('General Settings'):
+ with gr.Column(variant='panel'):
+
+ blink_mode = gr.Radio(['none', 'period'], value='period', label='blink mode', info="whether to blink periodly") #
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="temperature", value=0.2, info='audio to secc temperature',)
+ mouth_amp = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="mouth amplitude", value=0.45, info='higher -> mouth will open wider, default to be 0.4',)
+ out_mode = gr.Radio(['final', 'concat_debug'], value='final', label='output layout', info="final: only final output ; concat_debug: final output concated with internel features")
+ map_to_init_pose = gr.Checkbox(label="Whether to map pose of first frame to initial pose")
+ hold_eye_opened = gr.Checkbox(label="Whether to maintain eyes always open")
+ head_torso_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="head torso threshold", value=0.7, info='make it higher if you find ghosting around hair of output, default to be 0.7',)
+
+ submit = gr.Button('Generate', elem_id="generate", variant='primary')
+
+ with gr.Tabs(elem_id="genearted_video"):
+ info_box = gr.Textbox(label="Error", interactive=False, visible=False)
+ gen_video = gr.Video(label="Generated video", format="mp4", visible=True)
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="checkbox"):
+ with gr.TabItem('Checkpoints'):
+ with gr.Column(variant='panel'):
+ ckpt_info_box = gr.Textbox(value="Please select \"ckpt\" under the checkpoint folder ", interactive=False, visible=True, show_label=False)
+ audio2secc_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=audio2secc_dir, file_count='single', label='audio2secc model ckpt path or directory')
+ head_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=head_model_dir, file_count='single', label='head model ckpt path or directory (will be ignored if torso model is set)')
+ torso_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=torso_model_dir, file_count='single', label='torso model ckpt path or directory')
+ # audio2secc_dir = gr.Textbox(audio2secc_dir, max_lines=1, label='audio2secc model ckpt path or directory (will be ignored if torso model is set)')
+ # head_model_dir = gr.Textbox(head_model_dir, max_lines=1, label='head model ckpt path or directory (will be ignored if torso model is set)')
+ # torso_model_dir = gr.Textbox(torso_model_dir, max_lines=1, label='torso model ckpt path or directory')
+
+
+ fn = infer_obj.infer_once_args
+ if warpfn:
+ fn = warpfn(fn)
+ submit.click(
+ fn=fn,
+ inputs=[
+ src_image_name,
+ drv_audio_name,
+ drv_pose_name,
+ bg_image_name,
+ blink_mode,
+ temperature,
+ mouth_amp,
+ out_mode,
+ map_to_init_pose,
+ hold_eye_opened,
+ head_torso_threshold,
+ audio2secc_dir,
+ head_model_dir,
+ torso_model_dir,
+ ],
+ outputs=[
+ gen_video,
+ info_box,
+ ],
+ )
+
+ print(sep_line)
+ print("Gradio page is constructed.")
+ print(sep_line)
+
+ return real3dportrait_interface
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--a2m_ckpt", type=str, default='checkpoints/240126_real3dportrait_orig/audio2secc_vae/model_ckpt_steps_400000.ckpt')
+ parser.add_argument("--head_ckpt", type=str, default='')
+ parser.add_argument("--torso_ckpt", type=str, default='checkpoints/240126_real3dportrait_orig/secc2plane_torso_orig/model_ckpt_steps_100000.ckpt')
+ parser.add_argument("--port", type=int, default=None)
+ args = parser.parse_args()
+ demo = real3dportrait_demo(
+ audio2secc_dir=args.a2m_ckpt,
+ head_model_dir=args.head_ckpt,
+ torso_model_dir=args.torso_ckpt,
+ device='cuda:0',
+ warpfn=None,
+ )
+ demo.queue()
+ demo.launch(server_port=args.port)
+
diff --git a/inference/edit_secc.py b/inference/edit_secc.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c1e602b389665c2710eb76e5ab1244030096db2
--- /dev/null
+++ b/inference/edit_secc.py
@@ -0,0 +1,147 @@
+import cv2
+import torch
+from utils.commons.image_utils import dilate, erode
+from sklearn.neighbors import NearestNeighbors
+import copy
+import numpy as np
+from utils.commons.meters import Timer
+
+def hold_eye_opened_for_secc(img):
+ img = img.permute(1,2,0).cpu().numpy()
+ img = ((img +1)/2*255).astype(np.uint)
+ face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
+ face_xys = np.stack(np.nonzero(face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
+ h,w = face_mask.shape
+ # get face and eye mask
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
+ right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
+ coarse_eye_mask = (~ face_mask) & eye_prior_reigon
+ coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
+
+ opened_eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
+ opened_eye_mask = torch.nn.functional.interpolate(torch.tensor(opened_eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[0], img.shape[1]), mode='nearest')[0].permute(1,2,0).sum(-1).bool().cpu() # [512,512,3]
+ coarse_opened_eye_xys = np.stack(np.nonzero(opened_eye_mask)) # [N_nonbg,2] coordinate of non-face pixels
+
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
+ dists, _ = nbrs.kneighbors(coarse_opened_eye_xys) # [512*512, 1] distance to nearest non-bg pixel
+ # print(dists.max())
+ non_opened_eye_pixs = dists > max(dists.max()*0.75, 4) # 大于这个距离的opened eye部分会被合上
+ non_opened_eye_pixs = non_opened_eye_pixs.reshape([-1])
+ opened_eye_xys_to_erode = coarse_opened_eye_xys[non_opened_eye_pixs]
+ opened_eye_mask[opened_eye_xys_to_erode[...,0], opened_eye_xys_to_erode[...,1]] = False # shrink 将mask在face-eye边界收缩3pixel,为了平滑
+
+ img[opened_eye_mask] = 0
+ return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
+
+
+# def hold_eye_opened_for_secc(img):
+# img = copy.copy(img)
+# eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
+# eye_mask = torch.nn.functional.interpolate(torch.tensor(eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[-2], img.shape[-1]), mode='nearest')[0].bool().to(img.device) # [3,512,512]
+# img[eye_mask] = -1
+# return img
+
+def blink_eye_for_secc(img, close_eye_percent=0.5):
+ """
+ secc_img: [3,h,w], tensor, -1~1
+ """
+ img = img.permute(1,2,0).cpu().numpy()
+ img = ((img +1)/2*255).astype(np.uint)
+ assert close_eye_percent <= 1.0 and close_eye_percent >= 0.
+ if close_eye_percent == 0: return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
+ img = copy.deepcopy(img)
+ face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
+ h,w = face_mask.shape
+
+ # get face and eye mask
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
+ right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
+ coarse_eye_mask = (~ face_mask) & eye_prior_reigon
+ coarse_left_eye_mask = (~ face_mask) & left_eye_prior_reigon
+ coarse_right_eye_mask = (~ face_mask) & right_eye_prior_reigon
+ coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
+ min_h = coarse_eye_xys[:, 0].min()
+ max_h = coarse_eye_xys[:, 0].max()
+ coarse_left_eye_xys = np.stack(np.nonzero(coarse_left_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
+ left_min_w = coarse_left_eye_xys[:, 1].min()
+ left_max_w = coarse_left_eye_xys[:, 1].max()
+ coarse_right_eye_xys = np.stack(np.nonzero(coarse_right_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
+ right_min_w = coarse_right_eye_xys[:, 1].min()
+ right_max_w = coarse_right_eye_xys[:, 1].max()
+
+ # 尽力较少需要考虑的face_xyz,以降低KNN的损耗
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ more_room = 4 # 过小会导致一些问题
+ left_eye_prior_reigon[min_h-more_room:max_h+more_room, left_min_w-more_room:left_max_w+more_room] = True
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ right_eye_prior_reigon[min_h-more_room:max_h+more_room, right_min_w-more_room:right_max_w+more_room] = True
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
+
+ around_eye_face_mask = face_mask & eye_prior_reigon
+ face_mask = around_eye_face_mask
+ face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
+
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
+ dists, _ = nbrs.kneighbors(face_xys) # [512*512, 1] distance to nearest non-bg pixel
+ face_pixs = dists > 5 # 只有距离最近的eye pixel大于5的才被认为是face,过小会导致一些问题
+ face_pixs = face_pixs.reshape([-1])
+ face_xys_to_erode = face_xys[~face_pixs]
+ face_mask[face_xys_to_erode[...,0], face_xys_to_erode[...,1]] = False # shrink 将mask在face-eye边界收缩3pixel,为了平滑
+ eye_mask = (~ face_mask) & eye_prior_reigon
+
+ h_grid = np.mgrid[0:h, 0:w][0]
+ eye_num_pixel_along_w_axis = eye_mask.sum(axis=0)
+ eye_mask_along_w_axis = eye_num_pixel_along_w_axis != 0
+
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = 0
+ eye_mean_h_coord_along_w_axis = tmp_h_grid.sum(axis=0) / np.clip(eye_num_pixel_along_w_axis, a_min=1, a_max=h)
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = 99999
+ eye_min_h_coord_along_w_axis = tmp_h_grid.min(axis=0)
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = -99999
+ eye_max_h_coord_along_w_axis = tmp_h_grid.max(axis=0)
+
+ eye_low_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_min_h_coord_along_w_axis # upper eye
+ eye_high_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_max_h_coord_along_w_axis # lower eye
+
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = 99999
+ upper_eye_blink_mask = tmp_h_grid <= eye_low_h_coord_along_w_axis
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = -99999
+ lower_eye_blink_mask = tmp_h_grid >= eye_high_h_coord_along_w_axis
+ eye_blink_mask = upper_eye_blink_mask | lower_eye_blink_mask
+
+ face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
+ eye_blink_xys = np.stack(np.nonzero(eye_blink_mask)).transpose(1, 0) # [N_nonbg,hw] coordinate of non-face pixels
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(face_xys)
+ distances, indices = nbrs.kneighbors(eye_blink_xys)
+ bg_fg_xys = face_xys[indices[:, 0]]
+ img[eye_blink_xys[:, 0], eye_blink_xys[:, 1], :] = img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
+ return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
+
+
+if __name__ == '__main__':
+ import imageio
+ import tqdm
+ img = cv2.imread("assets/cano_secc.png")
+ img = img / 127.5 - 1
+ img = torch.FloatTensor(img).permute(2, 0, 1)
+ fps = 25
+ writer = imageio.get_writer('demo_blink.mp4', fps=fps)
+
+ for i in tqdm.trange(33):
+ blink_percent = 0.03 * i
+ with Timer("Blink", True):
+ out_img = blink_eye_for_secc(img, blink_percent)
+ out_img = ((out_img.permute(1,2,0)+1)*127.5).int().numpy()
+ writer.append_data(out_img)
+ writer.close()
\ No newline at end of file
diff --git a/inference/infer_utils.py b/inference/infer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb406fa8c734bd5295aae3f2a6e276f4b697da48
--- /dev/null
+++ b/inference/infer_utils.py
@@ -0,0 +1,154 @@
+import os
+import torch
+import torch.nn.functional as F
+import librosa
+import numpy as np
+import importlib
+import tqdm
+import copy
+import cv2
+from scipy.spatial.transform import Rotation
+
+
+def load_img_to_512_hwc_array(img_name):
+ img = cv2.imread(img_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = cv2.resize(img, (512, 512))
+ return img
+
+def load_img_to_normalized_512_bchw_tensor(img_name):
+ img = load_img_to_512_hwc_array(img_name)
+ img = ((torch.tensor(img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2) # [b,c,h,w]
+ return img
+
+def mirror_index(index, len_seq):
+ """
+ get mirror index when indexing a sequence and the index is larger than len_pose
+ args:
+ index: int
+ len_pose: int
+ return:
+ mirror_index: int
+ """
+ turn = index // len_seq
+ res = index % len_seq
+ if turn % 2 == 0:
+ return res # forward indexing
+ else:
+ return len_seq - res - 1 # reverse indexing
+
+def smooth_camera_sequence(camera, kernel_size=7):
+ """
+ smooth the camera trajectory (i.e., rotation & translation)...
+ args:
+ camera: [N, 25] or [N, 16]. np.ndarray
+ kernel_size: int
+ return:
+ smoothed_camera: [N, 25] or [N, 16]. np.ndarray
+ """
+ # poses: [N, 25], numpy array
+ N = camera.shape[0]
+ K = kernel_size // 2
+ poses = camera[:, :16].reshape([-1, 4, 4]).copy()
+ trans = poses[:, :3, 3].copy() # [N, 3]
+ rots = poses[:, :3, :3].copy() # [N, 3, 3]
+
+ for i in range(N):
+ start = max(0, i - K)
+ end = min(N, i + K + 1)
+ poses[i, :3, 3] = trans[start:end].mean(0)
+ try:
+ poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()
+ except:
+ if i == 0:
+ poses[i, :3, :3] = rots[i]
+ else:
+ poses[i, :3, :3] = poses[i-1, :3, :3]
+ poses = poses.reshape([-1, 16])
+ camera[:, :16] = poses
+ return camera
+
+def smooth_features_xd(in_tensor, kernel_size=7):
+ """
+ smooth the feature maps
+ args:
+ in_tensor: [T, c,h,w] or [T, c1,c2,h,w]
+ kernel_size: int
+ return:
+ out_tensor: [T, c,h,w] or [T, c1,c2,h,w]
+ """
+ t = in_tensor.shape[0]
+ ndim = in_tensor.ndim
+ pad = (kernel_size- 1)//2
+ in_tensor = torch.cat([torch.flip(in_tensor[0:pad], dims=[0]), in_tensor, torch.flip(in_tensor[t-pad:t], dims=[0])], dim=0)
+ if ndim == 2: # tc
+ _,c = in_tensor.shape
+ in_tensor = in_tensor.permute(1,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
+ elif ndim == 4: # tchw
+ _,c,h,w = in_tensor.shape
+ in_tensor = in_tensor.permute(1,2,3,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
+ elif ndim == 5: # tcchw, like deformation
+ _,c1,c2, h,w = in_tensor.shape
+ in_tensor = in_tensor.permute(1,2,3,4,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
+ else: raise NotImplementedError()
+ avg_kernel = 1 / kernel_size * torch.Tensor([1.]*kernel_size).reshape([1,1,kernel_size]).float().to(in_tensor.device) # [1, 1, kw]
+ out_tensor = F.conv1d(in_tensor, avg_kernel)
+ if ndim == 2: # tc
+ return out_tensor.reshape([c,t]).permute(1,0)
+ elif ndim == 4: # tchw
+ return out_tensor.reshape([c,h,w,t]).permute(3,0,1,2)
+ elif ndim == 5: # tcchw, like deformation
+ return out_tensor.reshape([c1,c2,h,w,t]).permute(4,0,1,2,3)
+
+
+def extract_audio_motion_from_ref_video(video_name):
+ def save_wav16k(audio_name):
+ supported_types = ('.wav', '.mp3', '.mp4', '.avi')
+ assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
+ wav16k_name = audio_name[:-4] + '_16k.wav'
+ extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name} -y"
+ os.system(extract_wav_cmd)
+ print(f"Extracted wav file (16khz) from {audio_name} to {wav16k_name}.")
+ return wav16k_name
+
+ def get_f0( wav16k_name):
+ from data_gen.process_lrs3.process_audio_mel_f0 import extract_mel_from_fname,extract_f0_from_wav_and_mel
+ wav, mel = extract_mel_from_fname(wav16k_name)
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
+ f0 = f0.reshape([-1,1])
+ f0 = torch.tensor(f0)
+ return f0
+
+ def get_hubert(wav16k_name):
+ from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
+ hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
+ len_mel = hubert.shape[0]
+ x_multiply = 8
+ if len_mel % x_multiply == 0:
+ num_to_pad = 0
+ else:
+ num_to_pad = x_multiply - len_mel % x_multiply
+ hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
+ hubert = torch.tensor(hubert)
+ return hubert
+
+ def get_exp(video_name):
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(video_name, save=False)
+ exp = torch.tensor(drv_motion_coeff_dict['exp'])
+ return exp
+
+ wav16k_name = save_wav16k(video_name)
+ f0 = get_f0(wav16k_name)
+ hubert = get_hubert(wav16k_name)
+ os.system(f"rm {wav16k_name}")
+ exp = get_exp(video_name)
+ target_length = min(len(exp), len(hubert)//2, len(f0)//2)
+ exp = exp[:target_length]
+ f0 = f0[:target_length*2]
+ hubert = hubert[:target_length*2]
+ return exp.unsqueeze(0), hubert.unsqueeze(0), f0.unsqueeze(0)
+
+
+if __name__ == '__main__':
+ extract_audio_motion_from_ref_video('data/raw/videos/crop_0213.mp4')
\ No newline at end of file
diff --git a/inference/real3d_infer.py b/inference/real3d_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4314e13b566acd208f9059588b50c4b09b9c91e4
--- /dev/null
+++ b/inference/real3d_infer.py
@@ -0,0 +1,542 @@
+import os
+import torch
+import torch.nn.functional as F
+import torchshow as ts
+import librosa
+import random
+import time
+import numpy as np
+import importlib
+import tqdm
+import copy
+import cv2
+
+# common utils
+from utils.commons.hparams import hparams, set_hparams
+from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
+from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
+# 3DMM-related utils
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+from data_util.face3d_helper import Face3DHelper
+from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
+from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
+from deep_3drecon.secc_renderer import SECC_Renderer
+from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
+# Face Parsing
+from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
+from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
+# other inference utils
+from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
+from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
+from Real3DPortrait.inference.edit_secc import blink_eye_for_secc
+
+
+def read_first_frame_from_a_video(vid_name):
+ frames = []
+ cap = cv2.VideoCapture(vid_name)
+ ret, frame_bgr = cap.read()
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+ return frame_rgb
+
+def analyze_weights_img(gen_output):
+ img_raw = gen_output['image_raw']
+ mask_005_to_03 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.3).repeat([1,3,1,1])
+ mask_005_to_05 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.5).repeat([1,3,1,1])
+ mask_005_to_07 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.7).repeat([1,3,1,1])
+ mask_005_to_09 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.9).repeat([1,3,1,1])
+ mask_005_to_10 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<1.0).repeat([1,3,1,1])
+
+ img_raw_005_to_03 = img_raw.clone()
+ img_raw_005_to_03[~mask_005_to_03] = -1
+ img_raw_005_to_05 = img_raw.clone()
+ img_raw_005_to_05[~mask_005_to_05] = -1
+ img_raw_005_to_07 = img_raw.clone()
+ img_raw_005_to_07[~mask_005_to_07] = -1
+ img_raw_005_to_09 = img_raw.clone()
+ img_raw_005_to_09[~mask_005_to_09] = -1
+ img_raw_005_to_10 = img_raw.clone()
+ img_raw_005_to_10[~mask_005_to_10] = -1
+ ts.save([img_raw_005_to_03[0], img_raw_005_to_05[0], img_raw_005_to_07[0], img_raw_005_to_09[0], img_raw_005_to_10[0]])
+
+class GeneFace2Infer:
+ def __init__(self, audio2secc_dir, head_model_dir, torso_model_dir, device=None, inp=None):
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.device = device
+ self.audio2secc_model = self.load_audio2secc(audio2secc_dir)
+ self.secc2video_model = self.load_secc2video(head_model_dir, torso_model_dir, inp)
+ self.audio2secc_model.to(device).eval()
+ self.secc2video_model.to(device).eval()
+ self.seg_model = MediapipeSegmenter()
+ self.secc_renderer = SECC_Renderer(512)
+ self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
+ self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
+
+ def load_audio2secc(self, audio2secc_dir):
+ config_name = f"{audio2secc_dir}/config.yaml" if not audio2secc_dir.endswith(".ckpt") else f"{os.path.dirname(audio2secc_dir)}/config.yaml"
+ set_hparams(f"{config_name}", print_hparams=False)
+ self.audio2secc_dir = audio2secc_dir
+ self.audio2secc_hparams = copy.deepcopy(hparams)
+ from modules.audio2motion.vae import VAEModel, PitchContourVAEModel
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
+ audio_in_dim = 1024
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
+ audio_in_dim = 13
+
+ if 'icl' in hparams['task_cls']:
+ self.use_icl_audio2motion = True
+ model = InContextAudio2MotionModel(hparams['icl_model_type'], hparams=self.audio2secc_hparams)
+ else:
+ self.use_icl_audio2motion = False
+ if hparams.get("use_pitch", False) is True:
+ model = PitchContourVAEModel(hparams, in_out_dim=64, audio_in_dim=audio_in_dim)
+ else:
+ model = VAEModel(in_out_dim=64, audio_in_dim=audio_in_dim)
+ load_ckpt(model, f"{audio2secc_dir}", model_name='model', strict=True)
+ return model
+
+ def load_secc2video(self, head_model_dir, torso_model_dir, inp):
+ if inp is None:
+ inp = {}
+ self.head_model_dir = head_model_dir
+ self.torso_model_dir = torso_model_dir
+ if torso_model_dir != '':
+ if torso_model_dir.endswith(".ckpt"):
+ set_hparams(f"{os.path.dirname(torso_model_dir)}/config.yaml", print_hparams=False)
+ else:
+ set_hparams(f"{torso_model_dir}/config.yaml", print_hparams=False)
+ if inp.get('head_torso_threshold', None) is not None:
+ hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
+ self.secc2video_hparams = copy.deepcopy(hparams)
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane_Torso
+ model = OSAvatarSECC_Img2plane_Torso()
+ load_ckpt(model, f"{torso_model_dir}", model_name='model', strict=False)
+ if head_model_dir != '':
+ print("| Warning: Assigned --torso_ckpt which also contains head, but --head_ckpt is also assigned, skipping the --head_ckpt.")
+ else:
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane
+ if head_model_dir.endswith(".ckpt"):
+ set_hparams(f"{os.path.dirname(head_model_dir)}/config.yaml", print_hparams=False)
+ else:
+ set_hparams(f"{head_model_dir}/config.yaml", print_hparams=False)
+ if inp.get('head_torso_threshold', None) is not None:
+ hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
+ self.secc2video_hparams = copy.deepcopy(hparams)
+ model = OSAvatarSECC_Img2plane()
+ load_ckpt(model, f"{head_model_dir}", model_name='model', strict=False)
+ return model
+
+ def infer_once(self, inp):
+ self.inp = inp
+ samples = self.prepare_batch_from_inp(inp)
+ seed = inp['seed'] if inp['seed'] is not None else int(time.time())
+ random.seed(seed)
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ out_name = self.forward_system(samples, inp)
+ return out_name
+
+ def prepare_batch_from_inp(self, inp):
+ """
+ :param inp: {'audio_source_name': (str)}
+ :return: a dict that contains the condition feature of NeRF
+ """
+ sample = {}
+ # Process Driving Motion
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
+ self.save_wav16k(inp['drv_audio_name'])
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
+ hubert = self.get_hubert(self.wav16k_name)
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
+ hubert = self.get_mfcc(self.wav16k_name) / 100
+
+ f0 = self.get_f0(self.wav16k_name)
+ if f0.shape[0] > len(hubert):
+ f0 = f0[:len(hubert)]
+ else:
+ num_to_pad = len(hubert) - len(f0)
+ f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0)))
+ t_x = hubert.shape[0]
+ x_mask = torch.ones([1, t_x]).float() # mask for audio frames
+ y_mask = torch.ones([1, t_x//2]).float() # mask for motion/image frames
+ sample.update({
+ 'hubert': torch.from_numpy(hubert).float().unsqueeze(0).cuda(),
+ 'f0': torch.from_numpy(f0).float().reshape([1,-1]).cuda(),
+ 'x_mask': x_mask.cuda(),
+ 'y_mask': y_mask.cuda(),
+ })
+ sample['blink'] = torch.zeros([1, t_x, 1]).long().cuda()
+ sample['audio'] = sample['hubert']
+ sample['eye_amp'] = torch.ones([1, 1]).cuda() * 1.0
+ sample['mouth_amp'] = torch.ones([1, 1]).cuda() * inp['mouth_amp']
+ elif inp['drv_audio_name'][-4:] in ['.mp4']:
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(inp['drv_audio_name'], save=False)
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
+ elif inp['drv_audio_name'][-4:] in ['.npy']:
+ drv_motion_coeff_dict = np.load(inp['drv_audio_name'], allow_pickle=True).tolist()
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
+
+ # Face Parsing
+ image_name = inp['src_image_name']
+ if image_name.endswith(".mp4"):
+ img = read_first_frame_from_a_video(image_name)
+ image_name = inp['src_image_name'] = image_name[:-4] + '.png'
+ cv2.imwrite(image_name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+ sample['ref_gt_img'] = load_img_to_normalized_512_bchw_tensor(image_name).cuda()
+ img = load_img_to_512_hwc_array(image_name)
+ segmap = self.seg_model._cal_seg_map(img)
+ sample['segmap'] = torch.tensor(segmap).float().unsqueeze(0).cuda()
+ head_img = self.seg_model._seg_out_img_with_segmap(img, segmap, mode='head')[0]
+ sample['ref_head_img'] = ((torch.tensor(head_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+ inpaint_torso_img, _, _, _ = inpaint_torso_job(img, segmap)
+ sample['ref_torso_img'] = ((torch.tensor(inpaint_torso_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+
+ if inp['bg_image_name'] == '':
+ bg_img = extract_background([img], [segmap], 'knn')
+ else:
+ bg_img = cv2.imread(inp['bg_image_name'])
+ bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
+ bg_img = cv2.resize(bg_img, (512,512))
+ sample['bg_img'] = ((torch.tensor(bg_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+
+ # 3DMM, get identity code and camera pose
+ coeff_dict = fit_3dmm_for_a_image(image_name, save=False)
+ assert coeff_dict is not None
+ src_id = torch.tensor(coeff_dict['id']).reshape([1,80]).cuda()
+ src_exp = torch.tensor(coeff_dict['exp']).reshape([1,64]).cuda()
+ src_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda()
+ src_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda()
+ sample['id'] = src_id.repeat([t_x//2,1])
+
+ # get the src_kp for torso model
+ src_kp = self.face3d_helper.reconstruct_lm2d(src_id, src_exp, src_euler, src_trans) # [1, 68, 2]
+ src_kp = (src_kp-0.5) / 0.5 # rescale to -1~1
+ sample['src_kp'] = torch.clamp(src_kp, -1, 1).repeat([t_x//2,1,1])
+
+ # get camera pose file
+ # random.seed(time.time())
+ inp['drv_pose_name'] = inp['drv_pose_name']
+ print(f"| To extract pose from {inp['drv_pose_name']}")
+
+ # extract camera pose
+ if inp['drv_pose_name'] == 'static':
+ sample['euler'] = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda().repeat([t_x//2,1]) # default static pose
+ sample['trans'] = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda().repeat([t_x//2,1])
+ else: # from file
+ if inp['drv_pose_name'].endswith('.mp4'):
+ # extract coeff from video
+ drv_pose_coeff_dict = fit_3dmm_for_a_video(inp['drv_pose_name'], save=False)
+ else:
+ # load from npy
+ drv_pose_coeff_dict = np.load(inp['drv_pose_name'], allow_pickle=True).tolist()
+ print(f"| Extracted pose from {inp['drv_pose_name']}")
+ eulers = convert_to_tensor(drv_pose_coeff_dict['euler']).reshape([-1,3]).cuda()
+ trans = convert_to_tensor(drv_pose_coeff_dict['trans']).reshape([-1,3]).cuda()
+ len_pose = len(eulers)
+ index_lst = [mirror_index(i, len_pose) for i in range(t_x//2)]
+ sample['euler'] = eulers[index_lst]
+ sample['trans'] = trans[index_lst]
+
+ # fix the z axis
+ sample['trans'][:, -1] = sample['trans'][0:1, -1].repeat([sample['trans'].shape[0]])
+
+ # mapping to the init pose
+ if inp.get("map_to_init_pose", 'False') == 'True':
+ diff_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda() - sample['euler'][0:1]
+ sample['euler'] = sample['euler'] + diff_euler
+ diff_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda() - sample['trans'][0:1]
+ sample['trans'] = sample['trans'] + diff_trans
+
+ # prepare camera
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':sample['euler'].cpu(), 'trans':sample['trans'].cpu()})
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
+ # smooth camera
+ camera_smo_ksize = 7
+ camera = np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)
+ camera = smooth_camera_sequence(camera, kernel_size=camera_smo_ksize) # [T, 25]
+ camera = torch.tensor(camera).cuda().float()
+ sample['camera'] = camera
+
+ return sample
+
+ @torch.no_grad()
+ def get_hubert(self, wav16k_name):
+ from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
+ hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
+ len_mel = hubert.shape[0]
+ x_multiply = 8
+ if len_mel % x_multiply == 0:
+ num_to_pad = 0
+ else:
+ num_to_pad = x_multiply - len_mel % x_multiply
+ hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
+ return hubert
+
+ def get_mfcc(self, wav16k_name):
+ from utils.audio import librosa_wav2mfcc
+ hparams['fft_size'] = 1200
+ hparams['win_size'] = 1200
+ hparams['hop_size'] = 480
+ hparams['audio_num_mel_bins'] = 80
+ hparams['fmin'] = 80
+ hparams['fmax'] = 12000
+ hparams['audio_sample_rate'] = 24000
+ mfcc = librosa_wav2mfcc(wav16k_name,
+ fft_size=hparams['fft_size'],
+ hop_size=hparams['hop_size'],
+ win_length=hparams['win_size'],
+ num_mels=hparams['audio_num_mel_bins'],
+ fmin=hparams['fmin'],
+ fmax=hparams['fmax'],
+ sample_rate=hparams['audio_sample_rate'],
+ center=True)
+ mfcc = np.array(mfcc).reshape([-1, 13])
+ len_mel = mfcc.shape[0]
+ x_multiply = 8
+ if len_mel % x_multiply == 0:
+ num_to_pad = 0
+ else:
+ num_to_pad = x_multiply - len_mel % x_multiply
+ mfcc = np.pad(mfcc, pad_width=((0,num_to_pad), (0,0)))
+ return mfcc
+
+ @torch.no_grad()
+ def forward_audio2secc(self, batch, inp=None):
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
+ # audio-to-exp
+ ret = {}
+ pred = self.audio2secc_model.forward(batch, ret=ret,train=False, temperature=inp['temperature'],)
+ print("| audio-to-motion finished")
+ if pred.shape[-1] == 144:
+ id = ret['pred'][0][:,:80]
+ exp = ret['pred'][0][:,80:]
+ else:
+ id = batch['id']
+ exp = ret['pred'][0]
+ if len(id) < len(exp): # happens when use ICL
+ id = torch.cat([id, id[0].unsqueeze(0).repeat([len(exp)-len(id),1])])
+ batch['id'] = id
+ batch['exp'] = exp
+ else:
+ drv_motion_coeff_dict = self.drv_motion_coeff_dict
+ batch['exp'] = torch.FloatTensor(drv_motion_coeff_dict['exp']).cuda()
+
+ batch = self.get_driving_motion(batch['id'], batch['exp'], batch['euler'], batch['trans'], batch, inp)
+ if self.use_icl_audio2motion:
+ self.audio2secc_model.empty_context()
+ return batch
+
+ @torch.no_grad()
+ def get_driving_motion(self, id, exp, euler, trans, batch, inp):
+ zero_eulers = torch.zeros([id.shape[0], 3]).to(id.device)
+ zero_trans = torch.zeros([id.shape[0], 3]).to(exp.device)
+ # render the secc given the id,exp
+ with torch.no_grad():
+ chunk_size = 50
+ drv_secc_color_lst = []
+ num_iters = len(id)//chunk_size if len(id)%chunk_size == 0 else len(id)//chunk_size+1
+ for i in tqdm.trange(num_iters, desc="rendering drv secc"):
+ torch.cuda.empty_cache()
+ face_mask, drv_secc_color = self.secc_renderer(id[i*chunk_size:(i+1)*chunk_size], exp[i*chunk_size:(i+1)*chunk_size], zero_eulers[i*chunk_size:(i+1)*chunk_size], zero_trans[i*chunk_size:(i+1)*chunk_size])
+ drv_secc_color_lst.append(drv_secc_color.cpu())
+ drv_secc_colors = torch.cat(drv_secc_color_lst, dim=0)
+ _, src_secc_color = self.secc_renderer(id[0:1], exp[0:1], zero_eulers[0:1], zero_trans[0:1])
+ _, cano_secc_color = self.secc_renderer(id[0:1], exp[0:1]*0, zero_eulers[0:1], zero_trans[0:1])
+ batch['drv_secc'] = drv_secc_colors.cuda()
+ batch['src_secc'] = src_secc_color.cuda()
+ batch['cano_secc'] = cano_secc_color.cuda()
+
+ # blinking secc
+ if inp['blink_mode'] == 'period':
+ period = 5 # second
+
+ for i in tqdm.trange(len(drv_secc_colors),desc="blinking secc"):
+ if i % (25*period) == 0:
+ blink_dur_frames = random.randint(8, 12)
+ for offset in range(blink_dur_frames):
+ j = offset + i
+ if j >= len(drv_secc_colors)-1: break
+ def blink_percent_fn(t, T):
+ return -4/T**2 * t**2 + 4/T * t
+ blink_percent = blink_percent_fn(offset, blink_dur_frames)
+ secc = batch['drv_secc'][j]
+ out_secc = blink_eye_for_secc(secc, blink_percent)
+ out_secc = out_secc.cuda()
+ batch['drv_secc'][j] = out_secc
+
+ # get the drv_kp for torso model, using the transformed trajectory
+ drv_kp = self.face3d_helper.reconstruct_lm2d(id, exp, euler, trans) # [T, 68, 2]
+
+ drv_kp = (drv_kp-0.5) / 0.5 # rescale to -1~1
+ batch['drv_kp'] = torch.clamp(drv_kp, -1, 1)
+ return batch
+
+ @torch.no_grad()
+ def forward_secc2video(self, batch, inp=None):
+ num_frames = len(batch['drv_secc'])
+ camera = batch['camera']
+ src_kps = batch['src_kp']
+ drv_kps = batch['drv_kp']
+ cano_secc_color = batch['cano_secc']
+ src_secc_color = batch['src_secc']
+ drv_secc_colors = batch['drv_secc']
+ ref_img_gt = batch['ref_gt_img']
+ ref_img_head = batch['ref_head_img']
+ ref_torso_img = batch['ref_torso_img']
+ bg_img = batch['bg_img']
+ segmap = batch['segmap']
+
+ # smooth torso drv_kp
+ torso_smo_ksize = 7
+ drv_kps = smooth_features_xd(drv_kps.reshape([-1, 68*2]), kernel_size=torso_smo_ksize).reshape([-1, 68, 2])
+
+ # forward renderer
+ img_raw_lst = []
+ img_lst = []
+ depth_img_lst = []
+ with torch.no_grad():
+ for i in tqdm.trange(num_frames, desc="Real3D-Portrait is rendering frames"):
+ kp_src = torch.cat([src_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(src_kps.device)],dim=-1)
+ kp_drv = torch.cat([drv_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(drv_kps.device)],dim=-1)
+ cond={'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_colors[i:i+1].cuda(),
+ 'ref_torso_img': ref_torso_img, 'bg_img': bg_img, 'segmap': segmap,
+ 'kp_s': kp_src, 'kp_d': kp_drv}
+ if i == 0:
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=True, use_cached_backbone=False)
+ else:
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
+ img_lst.append(gen_output['image'])
+ img_raw_lst.append(gen_output['image_raw'])
+ depth_img_lst.append(gen_output['image_depth'])
+
+ # save demo video
+ depth_imgs = torch.cat(depth_img_lst)
+ imgs = torch.cat(img_lst)
+ imgs_raw = torch.cat(img_raw_lst)
+ secc_img = torch.cat([torch.nn.functional.interpolate(drv_secc_colors[i:i+1], (512,512)) for i in range(num_frames)])
+
+ if inp['out_mode'] == 'concat_debug':
+ secc_img = secc_img.cpu()
+ secc_img = ((secc_img + 1) * 127.5).permute(0, 2, 3, 1).int().numpy()
+
+ depth_img = F.interpolate(depth_imgs, (512,512)).cpu()
+ depth_img = depth_img.repeat([1,3,1,1])
+ depth_img = (depth_img - depth_img.min()) / (depth_img.max() - depth_img.min())
+ depth_img = depth_img * 2 - 1
+ depth_img = depth_img.clamp(-1,1)
+
+ secc_img = secc_img / 127.5 - 1
+ secc_img = torch.from_numpy(secc_img).permute(0, 3, 1, 2)
+ imgs = torch.cat([ref_img_gt.repeat([imgs.shape[0],1,1,1]).cpu(), secc_img, F.interpolate(imgs_raw, (512,512)).cpu(), depth_img, imgs.cpu()], dim=-1)
+ elif inp['out_mode'] == 'final':
+ imgs = imgs.cpu()
+ elif inp['out_mode'] == 'debug':
+ raise NotImplementedError("to do: save separate videos")
+ imgs = imgs.clamp(-1,1)
+
+ import imageio
+ debug_name = 'demo.mp4'
+ out_imgs = ((imgs.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
+ writer = imageio.get_writer(debug_name, fps=25, format='FFMPEG', codec='h264')
+
+ for i in tqdm.trange(len(out_imgs), desc="Imageio is saving video"):
+ writer.append_data(out_imgs[i])
+ writer.close()
+
+ out_fname = 'infer_out/tmp/' + os.path.basename(inp['src_image_name'])[:-4] + '_' + os.path.basename(inp['drv_pose_name'])[:-4] + '.mp4' if inp['out_name'] == '' else inp['out_name']
+ try:
+ os.makedirs(os.path.dirname(out_fname), exist_ok=True)
+ except: pass
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
+ os.system(f"ffmpeg -i {debug_name} -i {self.wav16k_name} -y -v quiet -shortest {out_fname}")
+ os.system(f"rm {debug_name}")
+ os.system(f"rm {self.wav16k_name}")
+ else:
+ ret = os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -map 0:v -map 1:a -y -v quiet -shortest {out_fname}")
+ if ret != 0: # 没有成功从drv_audio_name里面提取到音频, 则直接输出无音频轨道的纯视频
+ os.system(f"mv {debug_name} {out_fname}")
+ print(f"Saved at {out_fname}")
+ return out_fname
+
+ @torch.no_grad()
+ def forward_system(self, batch, inp):
+ self.forward_audio2secc(batch, inp)
+ out_fname = self.forward_secc2video(batch, inp)
+ return out_fname
+
+ @classmethod
+ def example_run(cls, inp=None):
+ inp_tmp = {
+ 'drv_audio_name': 'data/raw/val_wavs/zozo.wav',
+ 'src_image_name': 'data/raw/val_imgs/Macron.png'
+ }
+ if inp is not None:
+ inp_tmp.update(inp)
+ inp = inp_tmp
+
+ infer_instance = cls(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp)
+ infer_instance.infer_once(inp)
+
+ ##############
+ # IO-related
+ ##############
+ def save_wav16k(self, audio_name):
+ supported_types = ('.wav', '.mp3', '.mp4', '.avi')
+ assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
+ wav16k_name = audio_name[:-4] + '_16k.wav'
+ self.wav16k_name = wav16k_name
+ extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name} -y"
+ os.system(extract_wav_cmd)
+ print(f"Extracted wav file (16khz) from {audio_name} to {wav16k_name}.")
+
+ def get_f0(self, wav16k_name):
+ from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel
+ wav, mel = extract_mel_from_fname(self.wav16k_name)
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
+ f0 = f0.reshape([-1,1])
+ return f0
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--a2m_ckpt", default='checkpoints/240126_real3dportrait_orig/audio2secc_vae', type=str)
+ parser.add_argument("--head_ckpt", default='', type=str)
+ parser.add_argument("--torso_ckpt", default='checkpoints/240126_real3dportrait_orig/secc2plane_torso_orig', type=str)
+ parser.add_argument("--src_img", default='', type=str) # data/raw/examples/Macron.png
+ parser.add_argument("--bg_img", default='', type=str) # data/raw/examples/bg.png
+ parser.add_argument("--drv_aud", default='', type=str) # data/raw/examples/Obama_5s.wav
+ parser.add_argument("--drv_pose", default='static', type=str) # data/raw/examples/May_5s.mp4
+ parser.add_argument("--blink_mode", default='none', type=str) # none | period
+ parser.add_argument("--temperature", default=0.2, type=float) # sampling temperature in audio2motion, higher -> more diverse, less accurate
+ parser.add_argument("--mouth_amp", default=0.45, type=float) # scale of predicted mouth, enabled in audio-driven
+ parser.add_argument("--head_torso_threshold", default=0.9, type=float, help="0.1~1.0, turn up this value if the hair is translucent")
+ parser.add_argument("--out_name", default='') # output filename
+ parser.add_argument("--out_mode", default='final') # final: only output talking head video; concat_debug: talking head with internel features
+ parser.add_argument("--map_to_init_pose", default='True') # whether to map the pose of first frame to source image
+ parser.add_argument("--seed", default=None, type=int) # random seed, default None to use time.time()
+
+ args = parser.parse_args()
+
+ inp = {
+ 'a2m_ckpt': args.a2m_ckpt,
+ 'head_ckpt': args.head_ckpt,
+ 'torso_ckpt': args.torso_ckpt,
+ 'src_image_name': args.src_img,
+ 'bg_image_name': args.bg_img,
+ 'drv_audio_name': args.drv_aud,
+ 'drv_pose_name': args.drv_pose,
+ 'blink_mode': args.blink_mode,
+ 'temperature': args.temperature,
+ 'mouth_amp': args.mouth_amp,
+ 'out_name': args.out_name,
+ 'out_mode': args.out_mode,
+ 'map_to_init_pose': args.map_to_init_pose,
+ 'head_torso_threshold': args.head_torso_threshold,
+ 'seed': args.seed,
+ }
+
+ GeneFace2Infer.example_run(inp)
\ No newline at end of file
diff --git a/insta.sh b/insta.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d8a4f7f20406600990707675c5faa4c51f31c35f
--- /dev/null
+++ b/insta.sh
@@ -0,0 +1,18 @@
+
+#conda create -n real3dportrait python=3.9
+#conda activate real3dportrait
+conda install conda-forge::ffmpeg # ffmpeg with libx264 codec to turn images to video
+
+### We recommend torch2.0.1+cuda11.7.
+conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+
+# Build from source, it may take a long time (Proxy is recommended if encountering the time-out problem)
+pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
+
+# MMCV for some network structure
+pip install cython
+pip install openmim==0.3.9
+mim install mmcv==2.1.0 # use mim to speed up installation for mmcv
+
+# other dependencies
+pip install -r docs/prepare_env/requirements.txt -v
diff --git a/modules/audio2motion/cnn_models.py b/modules/audio2motion/cnn_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58e8c472349f59ab1f733a384906644a0b796c2
--- /dev/null
+++ b/modules/audio2motion/cnn_models.py
@@ -0,0 +1,359 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def init_weights_func(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv1d") != -1:
+ torch.nn.init.xavier_uniform_(m.weight)
+
+
+class LambdaLayer(nn.Module):
+ def __init__(self, lambd):
+ super(LambdaLayer, self).__init__()
+ self.lambd = lambd
+
+ def forward(self, x):
+ return self.lambd(x)
+
+
+class LayerNorm(torch.nn.LayerNorm):
+ """Layer normalization module.
+ :param int nout: output dim size
+ :param int dim: dimension to be normalized
+ """
+
+ def __init__(self, nout, dim=-1, eps=1e-5):
+ """Construct an LayerNorm object."""
+ super(LayerNorm, self).__init__(nout, eps=eps)
+ self.dim = dim
+
+ def forward(self, x):
+ """Apply layer normalization.
+ :param torch.Tensor x: input tensor
+ :return: layer normalized tensor
+ :rtype torch.Tensor
+ """
+ if self.dim == -1:
+ return super(LayerNorm, self).forward(x)
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+
+class ResidualBlock(nn.Module):
+ """Implements conv->PReLU->norm n-times"""
+
+ def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
+ c_multiple=2, ln_eps=1e-12, bias=False):
+ super(ResidualBlock, self).__init__()
+
+ if norm_type == 'bn':
+ norm_builder = lambda: nn.BatchNorm1d(channels)
+ elif norm_type == 'in':
+ norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
+ elif norm_type == 'gn':
+ norm_builder = lambda: nn.GroupNorm(8, channels)
+ elif norm_type == 'ln':
+ norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
+ else:
+ norm_builder = lambda: nn.Identity()
+
+ self.blocks = [
+ nn.Sequential(
+ norm_builder(),
+ nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
+ padding=(dilation * (kernel_size - 1)) // 2, bias=bias),
+ LambdaLayer(lambda x: x * kernel_size ** -0.5),
+ nn.GELU(),
+ nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation, bias=bias),
+ )
+ for _ in range(n)
+ ]
+
+ self.blocks = nn.ModuleList(self.blocks)
+ self.dropout = dropout
+
+ def forward(self, x):
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
+ for b in self.blocks:
+ x_ = b(x)
+ if self.dropout > 0 and self.training:
+ x_ = F.dropout(x_, self.dropout, training=self.training)
+ x = x + x_
+ x = x * nonpadding
+ return x
+
+
+class ConvBlocks(nn.Module):
+ """Decodes the expanded phoneme encoding into spectrograms"""
+
+ def __init__(self, channels, out_dims, dilations, kernel_size,
+ norm_type='ln', layers_in_block=2, c_multiple=2,
+ dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, bias=False):
+ super(ConvBlocks, self).__init__()
+ self.is_BTC = is_BTC
+ self.res_blocks = nn.Sequential(
+ *[ResidualBlock(channels, kernel_size, d,
+ n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
+ dropout=dropout, ln_eps=ln_eps, bias=bias)
+ for d in dilations],
+ )
+ if norm_type == 'bn':
+ norm = nn.BatchNorm1d(channels)
+ elif norm_type == 'in':
+ norm = nn.InstanceNorm1d(channels, affine=True)
+ elif norm_type == 'gn':
+ norm = nn.GroupNorm(8, channels)
+ elif norm_type == 'ln':
+ norm = LayerNorm(channels, dim=1, eps=ln_eps)
+ self.last_norm = norm
+ self.post_net1 = nn.Conv1d(channels, out_dims, kernel_size=3, padding=1, bias=bias)
+ if init_weights:
+ self.apply(init_weights_func)
+
+ def forward(self, x):
+ """
+
+ :param x: [B, T, H]
+ :return: [B, T, H]
+ """
+ if self.is_BTC:
+ x = x.transpose(1, 2) # [B, C, T]
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
+ x = self.res_blocks(x) * nonpadding
+ x = self.last_norm(x) * nonpadding
+ x = self.post_net1(x) * nonpadding
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ return x
+
+
+class SeqLevelConvolutionalModel(nn.Module):
+ def __init__(self, out_dim=64, dropout=0.5, audio_feat_type='ppg', backbone_type='unet', norm_type='bn'):
+ nn.Module.__init__(self)
+ self.audio_feat_type = audio_feat_type
+ if audio_feat_type == 'ppg':
+ self.audio_encoder = nn.Sequential(*[
+ nn.Conv1d(29, 48, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(48) if norm_type=='bn' else LayerNorm(48, dim=1),
+ nn.GELU(),
+ nn.Conv1d(48, 48, 3, 1, 1, bias=False)
+ ])
+ self.energy_encoder = nn.Sequential(*[
+ nn.Conv1d(1, 16, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(16) if norm_type=='bn' else LayerNorm(16, dim=1),
+ nn.GELU(),
+ nn.Conv1d(16, 16, 3, 1, 1, bias=False)
+ ])
+ elif audio_feat_type == 'mel':
+ self.mel_encoder = nn.Sequential(*[
+ nn.Conv1d(80, 64, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(64) if norm_type=='bn' else LayerNorm(64, dim=1),
+ nn.GELU(),
+ nn.Conv1d(64, 64, 3, 1, 1, bias=False)
+ ])
+ else:
+ raise NotImplementedError("now only ppg or mel are supported!")
+
+ self.style_encoder = nn.Sequential(*[
+ nn.Linear(135, 256),
+ nn.GELU(),
+ nn.Linear(256, 256)
+ ])
+
+ if backbone_type == 'resnet':
+ self.backbone = ResNetBackbone()
+ elif backbone_type == 'unet':
+ self.backbone = UNetBackbone()
+ elif backbone_type == 'resblocks':
+ self.backbone = ResBlocksBackbone()
+ else:
+ raise NotImplementedError("Now only resnet and unet are supported!")
+
+ self.out_layer = nn.Sequential(
+ nn.BatchNorm1d(512) if norm_type=='bn' else LayerNorm(512, dim=1),
+ nn.Conv1d(512, 64, 3, 1, 1, bias=False),
+ nn.PReLU(),
+ nn.Conv1d(64, out_dim, 3, 1, 1, bias=False)
+ )
+ self.feat_dropout = nn.Dropout(p=dropout)
+
+ @property
+ def device(self):
+ return self.backbone.parameters().__next__().device
+
+ def forward(self, batch, ret, log_dict=None):
+ style, x_mask = batch['style'].to(self.device), batch['x_mask'].to(self.device)
+ style_feat = self.style_encoder(style) # [B,C=135] => [B,C=128]
+
+ if self.audio_feat_type == 'ppg':
+ audio, energy = batch['audio'].to(self.device), batch['energy'].to(self.device)
+ audio_feat = self.audio_encoder(audio.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=29] => [B,T,C=48]
+ energy_feat = self.energy_encoder(energy.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=1] => [B,T,C=16]
+ feat = torch.cat([audio_feat, energy_feat], dim=2) # [B,T,C=48+16]
+ elif self.audio_feat_type == 'mel':
+ mel = batch['mel'].to(self.device)
+ feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=64]
+
+ feat, x_mask = self.backbone(x=feat, sty=style_feat, x_mask=x_mask)
+
+ out = self.out_layer(feat.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T//2,C=256] => [B,T//2,C=64]
+
+ ret['pred'] = out
+ ret['mask'] = x_mask
+ return out
+
+
+class ResBlocksBackbone(nn.Module):
+ def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
+ super(ResBlocksBackbone,self).__init__()
+ self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*14, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear'))
+ self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear'))
+
+ self.dropout = nn.Dropout(p=p_dropout)
+
+ def forward(self, x, sty, x_mask=1.):
+ """
+ x: [B, T, C]
+ sty: [B, C=256]
+ x_mask: [B, T]
+ ret: [B, T/2, C]
+ """
+ x = x.transpose(1, 2) # [B, C, T]
+ x_mask = x_mask[:, None, :] # [B, 1, T]
+
+ x = self.resblocks_0(x) * x_mask # [B, C, T]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/2]
+ x = self.downsampler(x) * x_mask # [B, C, T/2]
+ x = self.resblocks_1(x) * x_mask # [B, C, T/2]
+ x = self.resblocks_2(x) * x_mask # [B, C, T/2]
+
+ x = self.dropout(x.transpose(1,2)).transpose(1,2)
+ sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/2]
+ x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/2]
+
+ x = self.resblocks_3(x) * x_mask # [B, C, T/2]
+ x = self.resblocks_4(x) * x_mask # [B, C, T/2]
+
+ x = x.transpose(1,2)
+ x_mask = x_mask.squeeze(1)
+ return x, x_mask
+
+
+
+class ResNetBackbone(nn.Module):
+ def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
+ super(ResNetBackbone,self).__init__()
+ self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*14, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear'))
+ self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear'))
+
+ self.dropout = nn.Dropout(p=p_dropout)
+
+ def forward(self, x, sty, x_mask=1.):
+ """
+ x: [B, T, C]
+ sty: [B, C=256]
+ x_mask: [B, T]
+ ret: [B, T/2, C]
+ """
+ x = x.transpose(1, 2) # [B, C, T]
+ x_mask = x_mask[:, None, :] # [B, 1, T]
+
+ x = self.resblocks_0(x) * x_mask # [B, C, T]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/2]
+ x = self.downsampler(x) * x_mask # [B, C, T/2]
+ x = self.resblocks_1(x) * x_mask # [B, C, T/2]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/4]
+ x = self.downsampler(x) * x_mask # [B, C, T/4]
+ x = self.resblocks_2(x) * x_mask # [B, C, T/4]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/8]
+ x = self.downsampler(x) * x_mask # [B, C, T/8]
+ x = self.dropout(x.transpose(1,2)).transpose(1,2)
+ sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/8]
+ x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/8]
+ x = self.resblocks_3(x) * x_mask # [B, C, T/8]
+
+ x_mask = self.upsampler(x_mask) # [B, 1, T/2]
+ x = self.upsampler(x) * x_mask # [B, C, T/2]
+ x = self.resblocks_4(x) * x_mask # [B, C, T/2]
+
+ x = x.transpose(1,2)
+ x_mask = x_mask.squeeze(1)
+ return x, x_mask
+
+
+class UNetBackbone(nn.Module):
+ def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
+ super(UNetBackbone, self).__init__()
+ self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*8, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_4 = ConvBlocks(channels=768, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) # [768 = c3(512) + c2(256)]
+ self.resblocks_5 = ConvBlocks(channels=640, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) # [640 = c4(512) + c1(128)]
+
+ self.downsampler = nn.Upsample(scale_factor=0.5, mode='linear')
+ self.upsampler = nn.Upsample(scale_factor=2, mode='linear')
+ self.dropout = nn.Dropout(p=p_dropout)
+
+ def forward(self, x, sty, x_mask=1.):
+ """
+ x: [B, T, C]
+ sty: [B, C=256]
+ x_mask: [B, T]
+ ret: [B, T/2, C]
+ """
+ x = x.transpose(1, 2) # [B, C, T]
+ x_mask = x_mask[:, None, :] # [B, 1, T]
+
+ x0 = self.resblocks_0(x) * x_mask # [B, C, T]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/2]
+ x = self.downsampler(x0) * x_mask # [B, C, T/2]
+ x1 = self.resblocks_1(x) * x_mask # [B, C, T/2]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/4]
+ x = self.downsampler(x1) * x_mask # [B, C, T/4]
+ x2 = self.resblocks_2(x) * x_mask # [B, C, T/4]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/8]
+ x = self.downsampler(x2) * x_mask # [B, C, T/8]
+ x = self.dropout(x.transpose(1,2)).transpose(1,2)
+ sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/8]
+ x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/8]
+ x3 = self.resblocks_3(x) * x_mask # [B, C, T/8]
+
+ x_mask = self.upsampler(x_mask) # [B, 1, T/4]
+ x = self.upsampler(x3) * x_mask # [B, C, T/4]
+ x = torch.cat([x, self.dropout(x2.transpose(1,2)).transpose(1,2)], dim=1) #
+ x4 = self.resblocks_4(x) * x_mask # [B, C, T/4]
+
+ x_mask = self.upsampler(x_mask) # [B, 1, T/2]
+ x = self.upsampler(x4) * x_mask # [B, C, T/2]
+ x = torch.cat([x, self.dropout(x1.transpose(1,2)).transpose(1,2)], dim=1)
+ x5 = self.resblocks_5(x) * x_mask # [B, C, T/2]
+
+ x = x5.transpose(1,2)
+ x_mask = x_mask.squeeze(1)
+ return x, x_mask
+
+
+if __name__ == '__main__':
+ pass
diff --git a/modules/audio2motion/flow_base.py b/modules/audio2motion/flow_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2ff1c626cc3e4aef72406e16971db7331aa5c85
--- /dev/null
+++ b/modules/audio2motion/flow_base.py
@@ -0,0 +1,838 @@
+import scipy
+from scipy import linalg
+from torch.nn import functional as F
+import torch
+from torch import nn
+import numpy as np
+
+import modules.audio2motion.utils as utils
+from modules.audio2motion.transformer_models import FFTBlocks
+from utils.commons.hparams import hparams
+
+
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+class WN(torch.nn.Module):
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0,
+ p_dropout=0, share_cond_layers=False):
+ super(WN, self).__init__()
+ assert (kernel_size % 2 == 1)
+ assert (hidden_channels % 2 == 0)
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+ self.share_cond_layers = share_cond_layers
+
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+
+ self.drop = nn.Dropout(p_dropout)
+
+ self.use_adapters = hparams.get("use_adapters", False)
+ if self.use_adapters:
+ self.adapter_layers = torch.nn.ModuleList()
+
+ if gin_channels != 0 and not share_cond_layers:
+ cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+
+ for i in range(n_layers):
+ dilation = dilation_rate ** i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
+ dilation=dilation, padding=padding)
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+ self.in_layers.append(in_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * hidden_channels
+ else:
+ res_skip_channels = hidden_channels
+
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+ self.res_skip_layers.append(res_skip_layer)
+
+ if self.use_adapters:
+ adapter_layer = MlpAdapter(in_out_dim=res_skip_channels, hid_dim=res_skip_channels//4)
+ self.adapter_layers.append(adapter_layer)
+
+ def forward(self, x, x_mask=None, g=None, **kwargs):
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+ if g is not None and not self.share_cond_layers:
+ g = self.cond_layer(g)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ x_in = self.drop(x_in)
+ if g is not None:
+ cond_offset = i * 2 * self.hidden_channels
+ g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
+ else:
+ g_l = torch.zeros_like(x_in)
+
+ acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if self.use_adapters:
+ res_skip_acts = self.adapter_layers[i](res_skip_acts.transpose(1,2)).transpose(1,2)
+ if i < self.n_layers - 1:
+ x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
+ output = output + res_skip_acts[:, self.hidden_channels:, :]
+ else:
+ output = output + res_skip_acts
+ return output * x_mask
+
+ def remove_weight_norm(self):
+ def remove_weight_norm(m):
+ try:
+ nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(remove_weight_norm)
+
+ def enable_adapters(self):
+ if not self.use_adapters:
+ return
+ for adapter_layer in self.adapter_layers:
+ adapter_layer.enable()
+
+ def disable_adapters(self):
+ if not self.use_adapters:
+ return
+ for adapter_layer in self.adapter_layers:
+ adapter_layer.disable()
+
+class Permute(nn.Module):
+ def __init__(self, *args):
+ super(Permute, self).__init__()
+ self.args = args
+
+ def forward(self, x):
+ return x.permute(self.args)
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-4):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ n_dims = len(x.shape)
+ mean = torch.mean(x, 1, keepdim=True)
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+ shape = [1, -1] + [1] * (n_dims - 2)
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+ return x
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ assert n_layers > 1, "Number of layers should be larger than 0."
+
+ self.conv_layers = nn.ModuleList()
+ self.norm_layers = nn.ModuleList()
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = nn.Sequential(
+ nn.ReLU(),
+ nn.Dropout(p_dropout))
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
+
+
+
+class ActNorm(nn.Module):
+ def __init__(self, channels, ddi=False, **kwargs):
+ super().__init__()
+ self.channels = channels
+ self.initialized = not ddi
+
+ self.logs = nn.Parameter(torch.zeros(1, channels, 1))
+ self.bias = nn.Parameter(torch.zeros(1, channels, 1))
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ if x_mask is None:
+ x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
+ x_len = torch.sum(x_mask, [1, 2])
+ if not self.initialized:
+ self.initialize(x, x_mask)
+ self.initialized = True
+
+ if reverse:
+ z = (x - self.bias) * torch.exp(-self.logs) * x_mask
+ logdet = torch.sum(-self.logs) * x_len
+ else:
+ z = (self.bias + torch.exp(self.logs) * x) * x_mask
+ logdet = torch.sum(self.logs) * x_len # [b]
+ return z, logdet
+
+ def store_inverse(self):
+ pass
+
+ def set_ddi(self, ddi):
+ self.initialized = not ddi
+
+ def initialize(self, x, x_mask):
+ with torch.no_grad():
+ denom = torch.sum(x_mask, [0, 2])
+ m = torch.sum(x * x_mask, [0, 2]) / denom
+ m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
+ v = m_sq - (m ** 2)
+ logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
+
+ bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
+ logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
+
+ self.bias.data.copy_(bias_init)
+ self.logs.data.copy_(logs_init)
+
+
+class InvConvNear(nn.Module):
+ def __init__(self, channels, n_split=4, no_jacobian=False, lu=True, n_sqz=2, **kwargs):
+ super().__init__()
+ assert (n_split % 2 == 0)
+ self.channels = channels
+ self.n_split = n_split
+ self.n_sqz = n_sqz
+ self.no_jacobian = no_jacobian
+
+ w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
+ if torch.det(w_init) < 0:
+ w_init[:, 0] = -1 * w_init[:, 0]
+ self.lu = lu
+ if lu:
+ # LU decomposition can slightly speed up the inverse
+ np_p, np_l, np_u = linalg.lu(w_init)
+ np_s = np.diag(np_u)
+ np_sign_s = np.sign(np_s)
+ np_log_s = np.log(np.abs(np_s))
+ np_u = np.triu(np_u, k=1)
+ l_mask = np.tril(np.ones(w_init.shape, dtype=float), -1)
+ eye = np.eye(*w_init.shape, dtype=float)
+
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)), requires_grad=True)
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)), requires_grad=True)
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)), requires_grad=True)
+ self.register_buffer('l_mask', torch.Tensor(l_mask))
+ self.register_buffer('eye', torch.Tensor(eye))
+ else:
+ self.weight = nn.Parameter(w_init)
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ b, c, t = x.size()
+ assert (c % self.n_split == 0)
+ if x_mask is None:
+ x_mask = 1
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
+ else:
+ x_len = torch.sum(x_mask, [1, 2])
+
+ x = x.view(b, self.n_sqz, c // self.n_split, self.n_split // self.n_sqz, t)
+ x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
+
+ if self.lu:
+ self.weight, log_s = self._get_weight()
+ logdet = log_s.sum()
+ logdet = logdet * (c / self.n_split) * x_len
+ else:
+ logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
+
+ if reverse:
+ if hasattr(self, "weight_inv"):
+ weight = self.weight_inv
+ else:
+ weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
+ logdet = -logdet
+ else:
+ weight = self.weight
+ if self.no_jacobian:
+ logdet = 0
+
+ weight = weight.view(self.n_split, self.n_split, 1, 1)
+ z = F.conv2d(x, weight)
+
+ z = z.view(b, self.n_sqz, self.n_split // self.n_sqz, c // self.n_split, t)
+ z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
+ return z, logdet
+
+ def _get_weight(self):
+ l, log_s, u = self.l, self.log_s, self.u
+ l = l * self.l_mask + self.eye
+ u = u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(log_s))
+ weight = torch.matmul(self.p, torch.matmul(l, u))
+ return weight, log_s
+
+ def store_inverse(self):
+ weight, _ = self._get_weight()
+ self.weight_inv = torch.inverse(weight.float()).to(next(self.parameters()).device)
+
+
+class InvConv(nn.Module):
+ def __init__(self, channels, no_jacobian=False, lu=True, **kwargs):
+ super().__init__()
+ w_shape = [channels, channels]
+ w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(float)
+ LU_decomposed = lu
+ if not LU_decomposed:
+ # Sample a random orthogonal matrix:
+ self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
+ else:
+ np_p, np_l, np_u = linalg.lu(w_init)
+ np_s = np.diag(np_u)
+ np_sign_s = np.sign(np_s)
+ np_log_s = np.log(np.abs(np_s))
+ np_u = np.triu(np_u, k=1)
+ l_mask = np.tril(np.ones(w_shape, dtype=float), -1)
+ eye = np.eye(*w_shape, dtype=float)
+
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)))
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)))
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)))
+ self.l_mask = torch.Tensor(l_mask)
+ self.eye = torch.Tensor(eye)
+ self.w_shape = w_shape
+ self.LU = LU_decomposed
+ self.weight = None
+
+ def get_weight(self, device, reverse):
+ w_shape = self.w_shape
+ self.p = self.p.to(device)
+ self.sign_s = self.sign_s.to(device)
+ self.l_mask = self.l_mask.to(device)
+ self.eye = self.eye.to(device)
+ l = self.l * self.l_mask + self.eye
+ u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
+ dlogdet = self.log_s.sum()
+ if not reverse:
+ w = torch.matmul(self.p, torch.matmul(l, u))
+ else:
+ l = torch.inverse(l.double()).float()
+ u = torch.inverse(u.double()).float()
+ w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
+ return w.view(w_shape[0], w_shape[1], 1), dlogdet
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ """
+ log-det = log|abs(|W|)| * pixels
+ """
+ b, c, t = x.size()
+ if x_mask is None:
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
+ else:
+ x_len = torch.sum(x_mask, [1, 2])
+ logdet = 0
+ if not reverse:
+ weight, dlogdet = self.get_weight(x.device, reverse)
+ z = F.conv1d(x, weight)
+ if logdet is not None:
+ logdet = logdet + dlogdet * x_len
+ return z, logdet
+ else:
+ if self.weight is None:
+ weight, dlogdet = self.get_weight(x.device, reverse)
+ else:
+ weight, dlogdet = self.weight, self.dlogdet
+ z = F.conv1d(x, weight)
+ if logdet is not None:
+ logdet = logdet - dlogdet * x_len
+ return z, logdet
+
+ def store_inverse(self):
+ self.weight, self.dlogdet = self.get_weight('cuda', reverse=True)
+
+
+class Flip(nn.Module):
+ def forward(self, x, *args, reverse=False, **kwargs):
+ x = torch.flip(x, [1])
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+ return x, logdet
+
+ def store_inverse(self):
+ pass
+
+
+class CouplingBlock(nn.Module):
+ def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers,
+ gin_channels=0, p_dropout=0, sigmoid_scale=False,
+ share_cond_layers=False, wn=None):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+ self.sigmoid_scale = sigmoid_scale
+
+ start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
+ start = torch.nn.utils.weight_norm(start)
+ self.start = start
+ # Initializing last layer to 0 makes the affine coupling layers
+ # do nothing at first. This helps with training stability
+ end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
+ end.weight.data.zero_()
+ end.bias.data.zero_()
+ self.end = end
+ self.wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels,
+ p_dropout, share_cond_layers)
+ if wn is not None:
+ self.wn.in_layers = wn.in_layers
+ self.wn.res_skip_layers = wn.res_skip_layers
+
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
+ if x_mask is None:
+ x_mask = 1
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
+
+ x = self.start(x_0) * x_mask
+ x = self.wn(x, x_mask, g)
+ out = self.end(x)
+
+ z_0 = x_0
+ m = out[:, :self.in_channels // 2, :]
+ logs = out[:, self.in_channels // 2:, :]
+ if self.sigmoid_scale:
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
+ if reverse:
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
+ logdet = torch.sum(-logs * x_mask, [1, 2])
+ else:
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
+ logdet = torch.sum(logs * x_mask, [1, 2])
+ z = torch.cat([z_0, z_1], 1)
+ return z, logdet
+
+ def store_inverse(self):
+ self.wn.remove_weight_norm()
+
+
+class GlowFFTBlocks(FFTBlocks):
+ def __init__(self, hidden_size=128, gin_channels=256, num_layers=2, ffn_kernel_size=5,
+ dropout=None, num_heads=4, use_pos_embed=True, use_last_norm=True,
+ norm='ln', use_pos_embed_alpha=True):
+ super().__init__(hidden_size, num_layers, ffn_kernel_size, dropout, num_heads, use_pos_embed,
+ use_last_norm, norm, use_pos_embed_alpha)
+ self.inp_proj = nn.Conv1d(hidden_size + gin_channels, hidden_size, 1)
+
+ def forward(self, x, x_mask=None, g=None):
+ """
+ :param x: [B, C_x, T]
+ :param x_mask: [B, 1, T]
+ :param g: [B, C_g, T]
+ :return: [B, C_x, T]
+ """
+ if g is not None:
+ x = self.inp_proj(torch.cat([x, g], 1))
+ x = x.transpose(1, 2)
+ x = super(GlowFFTBlocks, self).forward(x, x_mask[:, 0] == 0)
+ x = x.transpose(1, 2)
+ return x
+
+
+class TransformerCouplingBlock(nn.Module):
+ def __init__(self, in_channels, hidden_channels, n_layers,
+ gin_channels=0, p_dropout=0, sigmoid_scale=False):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+ self.sigmoid_scale = sigmoid_scale
+
+ start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
+ self.start = start
+ # Initializing last layer to 0 makes the affine coupling layers
+ # do nothing at first. This helps with training stability
+ end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
+ end.weight.data.zero_()
+ end.bias.data.zero_()
+ self.end = end
+ self.fft_blocks = GlowFFTBlocks(
+ hidden_size=hidden_channels,
+ ffn_kernel_size=3,
+ gin_channels=gin_channels,
+ num_layers=n_layers)
+
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
+ if x_mask is None:
+ x_mask = 1
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
+
+ x = self.start(x_0) * x_mask
+ x = self.fft_blocks(x, x_mask, g)
+ out = self.end(x)
+
+ z_0 = x_0
+ m = out[:, :self.in_channels // 2, :]
+ logs = out[:, self.in_channels // 2:, :]
+ if self.sigmoid_scale:
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
+ if reverse:
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
+ logdet = torch.sum(-logs * x_mask, [1, 2])
+ else:
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
+ logdet = torch.sum(logs * x_mask, [1, 2])
+ z = torch.cat([z_0, z_1], 1)
+ return z, logdet
+
+ def store_inverse(self):
+ pass
+
+
+class FreqFFTCouplingBlock(nn.Module):
+ def __init__(self, in_channels, hidden_channels, n_layers,
+ gin_channels=0, p_dropout=0, sigmoid_scale=False):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+ self.sigmoid_scale = sigmoid_scale
+
+ hs = hidden_channels
+ stride = 8
+ self.start = torch.nn.Conv2d(3, hs, kernel_size=stride * 2,
+ stride=stride, padding=stride // 2)
+ end = nn.ConvTranspose2d(hs, 2, kernel_size=stride, stride=stride)
+ end.weight.data.zero_()
+ end.bias.data.zero_()
+ self.end = nn.Sequential(
+ nn.Conv2d(hs * 3, hs, 3, 1, 1),
+ nn.ReLU(),
+ nn.GroupNorm(4, hs),
+ nn.Conv2d(hs, hs, 3, 1, 1),
+ end
+ )
+ self.fft_v = FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers)
+ self.fft_h = nn.Sequential(
+ nn.Conv1d(hs, hs, 3, 1, 1),
+ nn.ReLU(),
+ nn.Conv1d(hs, hs, 3, 1, 1),
+ )
+ self.fft_g = nn.Sequential(
+ nn.Conv1d(
+ gin_channels - 160, hs, kernel_size=stride * 2, stride=stride, padding=stride // 2),
+ Permute(0, 2, 1),
+ FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers),
+ Permute(0, 2, 1),
+ )
+
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
+ g_, _ = utils.unsqueeze(g)
+ g_mel = g_[:, :80]
+ g_txt = g_[:, 80:]
+ g_mel, _ = utils.squeeze(g_mel)
+ g_txt, _ = utils.squeeze(g_txt) # [B, C, T]
+
+ if x_mask is None:
+ x_mask = 1
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
+ x = torch.stack([x_0, g_mel[:, :80], g_mel[:, 80:]], 1)
+ x = self.start(x) # [B, C, N_bins, T]
+ B, C, N_bins, T = x.shape
+
+ x_v = self.fft_v(x.permute(0, 3, 2, 1).reshape(B * T, N_bins, C))
+ x_v = x_v.reshape(B, T, N_bins, -1).permute(0, 3, 2, 1)
+ # x_v = x
+
+ x_h = self.fft_h(x.permute(0, 2, 1, 3).reshape(B * N_bins, C, T))
+ x_h = x_h.reshape(B, N_bins, -1, T).permute(0, 2, 1, 3)
+ # x_h = x
+
+ x_g = self.fft_g(g_txt)[:, :, None, :].repeat(1, 1, 10, 1)
+ x = torch.cat([x_v, x_h, x_g], 1)
+ out = self.end(x)
+
+ z_0 = x_0
+ m = out[:, 0]
+ logs = out[:, 1]
+ if self.sigmoid_scale:
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
+ if reverse:
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
+ logdet = torch.sum(-logs * x_mask, [1, 2])
+ else:
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
+ logdet = torch.sum(logs * x_mask, [1, 2])
+ z = torch.cat([z_0, z_1], 1)
+ return z, logdet
+
+ def store_inverse(self):
+ pass
+
+
+
+class ResidualCouplingLayer(nn.Module):
+ def __init__(self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=0,
+ gin_channels=0,
+ mean_only=False,
+ nn_type='wn'):
+ assert channels % 2 == 0, "channels should be divisible by 2"
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.half_channels = channels // 2
+ self.mean_only = mean_only
+
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+ if nn_type == 'wn':
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout,
+ gin_channels=gin_channels)
+ # elif nn_type == 'conv':
+ # self.enc = ConditionalConvBlocks(
+ # hidden_channels, gin_channels, hidden_channels, [1] * n_layers, kernel_size,
+ # layers_in_block=1, is_BTC=False)
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+ self.post.weight.data.zero_()
+ self.post.bias.data.zero_()
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+ h = self.pre(x0) * x_mask
+ h = self.enc(h, x_mask=x_mask, g=g)
+ stats = self.post(h) * x_mask
+ if not self.mean_only:
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+ else:
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not reverse:
+ x1 = m + x1 * torch.exp(logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ logdet = torch.sum(logs, [1, 2])
+ return x, logdet
+ else:
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ logdet = -torch.sum(logs, [1, 2])
+ return x, logdet
+
+
+class ResidualCouplingBlock(nn.Module):
+ def __init__(self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ n_flows=4,
+ gin_channels=0,
+ nn_type='wn'):
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.n_flows = n_flows
+ self.gin_channels = gin_channels
+
+ self.flows = nn.ModuleList()
+ for i in range(n_flows):
+ self.flows.append(
+ ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
+ gin_channels=gin_channels, mean_only=True, nn_type=nn_type))
+ self.flows.append(Flip())
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ if not reverse:
+ for flow in self.flows:
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
+ else:
+ for flow in reversed(self.flows):
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
+ return x
+
+
+class Glow(nn.Module):
+ def __init__(self,
+ in_channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_blocks,
+ n_layers,
+ p_dropout=0.,
+ n_split=4,
+ n_sqz=2,
+ sigmoid_scale=False,
+ gin_channels=0,
+ inv_conv_type='near',
+ share_cond_layers=False,
+ share_wn_layers=0,
+ ):
+ super().__init__()
+ """
+ Note that regularization likes weight decay can leads to Nan error!
+ """
+
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_blocks = n_blocks
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ self.n_split = n_split
+ self.n_sqz = n_sqz
+ self.sigmoid_scale = sigmoid_scale
+ self.gin_channels = gin_channels
+ self.share_cond_layers = share_cond_layers
+ if gin_channels != 0 and share_cond_layers:
+ cond_layer = torch.nn.Conv1d(gin_channels * n_sqz, 2 * hidden_channels * n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+ wn = None
+ self.flows = nn.ModuleList()
+ for b in range(n_blocks):
+ self.flows.append(ActNorm(channels=in_channels * n_sqz))
+ if inv_conv_type == 'near':
+ self.flows.append(InvConvNear(channels=in_channels * n_sqz, n_split=n_split, n_sqz=n_sqz))
+ if inv_conv_type == 'invconv':
+ self.flows.append(InvConv(channels=in_channels * n_sqz))
+ if share_wn_layers > 0:
+ if b % share_wn_layers == 0:
+ wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels * n_sqz,
+ p_dropout, share_cond_layers)
+ self.flows.append(
+ CouplingBlock(
+ in_channels * n_sqz,
+ hidden_channels,
+ kernel_size=kernel_size,
+ dilation_rate=dilation_rate,
+ n_layers=n_layers,
+ gin_channels=gin_channels * n_sqz,
+ p_dropout=p_dropout,
+ sigmoid_scale=sigmoid_scale,
+ share_cond_layers=share_cond_layers,
+ wn=wn
+ ))
+
+ def forward(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False):
+ """
+ x: [B,T,C]
+ x_mask: [B,T]
+ g: [B,T,C]
+ """
+ x = x.transpose(1,2)
+ x_mask = x_mask.unsqueeze(1)
+ if g is not None:
+ g = g.transpose(1,2)
+
+ logdet_tot = 0
+ if not reverse:
+ flows = self.flows
+ else:
+ flows = reversed(self.flows)
+ if return_hiddens:
+ hs = []
+ if self.n_sqz > 1:
+ x, x_mask_ = utils.squeeze(x, x_mask, self.n_sqz)
+ if g is not None:
+ g, _ = utils.squeeze(g, x_mask, self.n_sqz)
+ x_mask = x_mask_
+ if self.share_cond_layers and g is not None:
+ g = self.cond_layer(g)
+ for f in flows:
+ x, logdet = f(x, x_mask, g=g, reverse=reverse)
+ if return_hiddens:
+ hs.append(x)
+ logdet_tot += logdet
+ if self.n_sqz > 1:
+ x, x_mask = utils.unsqueeze(x, x_mask, self.n_sqz)
+
+ x = x.transpose(1,2)
+ if return_hiddens:
+ return x, logdet_tot, hs
+ return x, logdet_tot
+
+ def store_inverse(self):
+ def remove_weight_norm(m):
+ try:
+ nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(remove_weight_norm)
+ for f in self.flows:
+ f.store_inverse()
+
+
+if __name__ == '__main__':
+ model = Glow(in_channels=64,
+ hidden_channels=128,
+ kernel_size=5,
+ dilation_rate=1,
+ n_blocks=12,
+ n_layers=4,
+ p_dropout=0.0,
+ n_split=4,
+ n_sqz=2,
+ sigmoid_scale=False,
+ gin_channels=80
+ )
+ exp = torch.rand([1,1440,64])
+ mel = torch.rand([1,1440,80])
+ x_mask = torch.ones([1,1440],dtype=torch.float32)
+ y, logdet = model(exp, x_mask,g=mel, reverse=False)
+ pred_exp, logdet = model(y, x_mask,g=mel, reverse=False)
+ # y: [b, t,c=64]
+ print(" ")
\ No newline at end of file
diff --git a/modules/audio2motion/multi_length_disc.py b/modules/audio2motion/multi_length_disc.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a57df2cef929691f2f1fa41981ed8316ff5dce6
--- /dev/null
+++ b/modules/audio2motion/multi_length_disc.py
@@ -0,0 +1,340 @@
+import numpy as np
+import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from modules.audio2motion.cnn_models import LambdaLayer
+
+
+class Discriminator1DFactory(nn.Module):
+ def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'):
+ super(Discriminator1DFactory, self).__init__()
+ padding = kernel_size // 2
+
+ def discriminator_block(in_filters, out_filters, first=False):
+ """
+ Input: (B, c, T)
+ Output:(B, c, T//2)
+ """
+ conv = nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding)
+ block = [
+ conv, # padding = kernel//2
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25)
+ ]
+ if norm_type == 'bn' and not first:
+ block.append(nn.BatchNorm1d(out_filters, 0.8))
+ if norm_type == 'in' and not first:
+ block.append(nn.InstanceNorm1d(out_filters, affine=True))
+ block = nn.Sequential(*block)
+ return block
+
+ if time_length >= 8:
+ self.model = nn.ModuleList([
+ discriminator_block(in_dim, hidden_size, first=True),
+ discriminator_block(hidden_size, hidden_size),
+ discriminator_block(hidden_size, hidden_size),
+ ])
+ ds_size = time_length // (2 ** 3)
+ elif time_length == 3:
+ self.model = nn.ModuleList([
+ nn.Sequential(*[
+ nn.Conv1d(in_dim, hidden_size, 3, 1, 0),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25),
+ nn.Conv1d(hidden_size, hidden_size, 1, 1, 0),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25),
+ nn.BatchNorm1d(hidden_size, 0.8),
+ nn.Conv1d(hidden_size, hidden_size, 1, 1, 0),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25),
+ nn.BatchNorm1d(hidden_size, 0.8)
+ ])
+ ])
+ ds_size = 1
+ elif time_length == 1:
+ self.model = nn.ModuleList([
+ nn.Sequential(*[
+ nn.Linear(in_dim, hidden_size),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25),
+ nn.Linear(hidden_size, hidden_size),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25),
+ ])
+ ])
+ ds_size = 1
+
+ self.adv_layer = nn.Linear(hidden_size * ds_size, 1)
+
+ def forward(self, x):
+ """
+
+ :param x: [B, C, T]
+ :return: validity: [B, 1], h: List of hiddens
+ """
+ h = []
+ if x.shape[-1] == 1:
+ x = x.squeeze(-1)
+ for l in self.model:
+ x = l(x)
+ h.append(x)
+ if x.ndim == 2:
+ b, ct = x.shape
+ use_sigmoid = True
+ else:
+ b, c, t = x.shape
+ ct = c * t
+ use_sigmoid = False
+ x = x.view(b, ct)
+ validity = self.adv_layer(x) # [B, 1]
+ if use_sigmoid:
+ validity = torch.sigmoid(validity)
+ return validity, h
+
+
+class CosineDiscriminator1DFactory(nn.Module):
+ def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'):
+ super().__init__()
+ padding = kernel_size // 2
+
+ def discriminator_block(in_filters, out_filters, first=False):
+ """
+ Input: (B, c, T)
+ Output:(B, c, T//2)
+ """
+ conv = nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding)
+ block = [
+ conv, # padding = kernel//2
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25)
+ ]
+ if norm_type == 'bn' and not first:
+ block.append(nn.BatchNorm1d(out_filters, 0.8))
+ if norm_type == 'in' and not first:
+ block.append(nn.InstanceNorm1d(out_filters, affine=True))
+ block = nn.Sequential(*block)
+ return block
+
+ self.model1 = nn.ModuleList([
+ discriminator_block(in_dim, hidden_size, first=True),
+ discriminator_block(hidden_size, hidden_size),
+ discriminator_block(hidden_size, hidden_size),
+ ])
+
+ self.model2 = nn.ModuleList([
+ discriminator_block(in_dim, hidden_size, first=True),
+ discriminator_block(hidden_size, hidden_size),
+ discriminator_block(hidden_size, hidden_size),
+ ])
+
+ self.relu = nn.ReLU()
+ def forward(self, x1, x2):
+ """
+
+ :param x1: [B, C, T]
+ :param x2: [B, C, T]
+ :return: validity: [B, 1], h: List of hiddens
+ """
+ h1, h2 = [], []
+ for l in self.model1:
+ x1 = l(x1)
+ h1.append(x1)
+ for l in self.model2:
+ x2 = l(x2)
+ h2.append(x1)
+ b,c,t = x1.shape
+ x1 = x1.view(b, c*t)
+ x2 = x2.view(b, c*t)
+ x1 = self.relu(x1)
+ x2 = self.relu(x2)
+ # x1 = F.normalize(x1, p=2, dim=1)
+ # x2 = F.normalize(x2, p=2, dim=1)
+ validity = F.cosine_similarity(x1, x2)
+ return validity, [h1,h2]
+
+
+class MultiWindowDiscriminator(nn.Module):
+ def __init__(self, time_lengths, cond_dim=80, in_dim=64, kernel_size=3, hidden_size=128, disc_type='standard', norm_type='bn', reduction='sum'):
+ super(MultiWindowDiscriminator, self).__init__()
+ self.win_lengths = time_lengths
+ self.reduction = reduction
+ self.disc_type = disc_type
+
+ if cond_dim > 0:
+ self.use_cond = True
+ self.cond_proj_layers = nn.ModuleList()
+ self.in_proj_layers = nn.ModuleList()
+ else:
+ self.use_cond = False
+
+ self.conv_layers = nn.ModuleList()
+ for time_length in time_lengths:
+ conv_layer = [
+ Discriminator1DFactory(
+ time_length, kernel_size, in_dim=64, hidden_size=hidden_size,
+ norm_type=norm_type) if self.disc_type == 'standard'
+ else CosineDiscriminator1DFactory(time_length, kernel_size, in_dim=64,
+ hidden_size=hidden_size,norm_type=norm_type)
+ ]
+ self.conv_layers += conv_layer
+ if self.use_cond:
+ self.cond_proj_layers.append(nn.Linear(cond_dim, 64))
+ self.in_proj_layers.append(nn.Linear(in_dim, 64))
+
+ def clip(self, x, cond, x_len, win_length, start_frames=None):
+ '''Ramdom clip x to win_length.
+ Args:
+ x (tensor) : (B, T, C).
+ cond (tensor) : (B, T, H).
+ x_len (tensor) : (B,).
+ win_length (int): target clip length
+
+ Returns:
+ (tensor) : (B, c_in, win_length, n_bins).
+
+ '''
+ clip_from_same_frame = start_frames is None
+ T_start = 0
+ # T_end = x_len.max() - win_length
+ T_end = x_len.min() - win_length
+ if T_end < 0:
+ return None, None, start_frames
+ T_end = T_end.item()
+ if start_frames is None:
+ start_frame = np.random.randint(low=T_start, high=T_end + 1)
+ start_frames = [start_frame] * x.size(0)
+ else:
+ start_frame = start_frames[0]
+
+
+ if clip_from_same_frame:
+ x_batch = x[:, start_frame: start_frame + win_length, :]
+ c_batch = cond[:, start_frame: start_frame + win_length, :] if cond is not None else None
+ else:
+ x_lst = []
+ c_lst = []
+ for i, start_frame in enumerate(start_frames):
+ x_lst.append(x[i, start_frame: start_frame + win_length, :])
+ if cond is not None:
+ c_lst.append(cond[i, start_frame: start_frame + win_length, :])
+ x_batch = torch.stack(x_lst, dim=0)
+ if cond is None:
+ c_batch = None
+ else:
+ c_batch = torch.stack(c_lst, dim=0)
+ return x_batch, c_batch, start_frames
+
+ def forward(self, x, x_len, cond=None, start_frames_wins=None):
+ '''
+ Args:
+ x (tensor): input mel, (B, T, C).
+ x_length (tensor): len of per mel. (B,).
+
+ Returns:
+ tensor : (B).
+ '''
+ validity = []
+ if start_frames_wins is None:
+ start_frames_wins = [None] * len(self.conv_layers)
+ h = []
+ for i, start_frames in zip(range(len(self.conv_layers)), start_frames_wins):
+ x_clip, c_clip, start_frames = self.clip(
+ x, cond, x_len, self.win_lengths[i], start_frames) # (B, win_length, C)
+ start_frames_wins[i] = start_frames
+ if x_clip is None:
+ continue
+ if self.disc_type == 'standard':
+ if self.use_cond:
+ x_clip = self.in_proj_layers[i](x_clip) # (B, T, C)
+ c_clip = self.cond_proj_layers[i](c_clip)
+ x_clip = x_clip + c_clip
+ validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1,2))
+ elif self.disc_type == 'cosine':
+ assert self.use_cond is True
+ x_clip = self.in_proj_layers[i](x_clip) # (B, T, C)
+ c_clip = self.cond_proj_layers[i](c_clip)
+ validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1,2), c_clip.transpose(1,2))
+ else:
+ raise NotImplementedError
+
+ h += h_
+ validity.append(validity_pred)
+ if len(validity) != len(self.conv_layers):
+ return None, start_frames_wins, h
+ if self.reduction == 'sum':
+ validity = sum(validity) # [B]
+ elif self.reduction == 'stack':
+ validity = torch.stack(validity, -1) # [B, W_L]
+ return validity, start_frames_wins, h
+
+
+class Discriminator(nn.Module):
+ def __init__(self, x_dim=80, y_dim=64, disc_type='standard',
+ uncond_disc=False, kernel_size=3, hidden_size=128, norm_type='bn', reduction='sum', time_lengths=(8,16,32)):
+ """_summary_
+
+ Args:
+ time_lengths (list, optional): the list of window size. Defaults to [32, 64, 128].
+ x_dim (int, optional): the dim of audio features. Defaults to 80, corresponding to mel-spec.
+ y_dim (int, optional): the dim of facial coeff. Defaults to 64, correspond to exp; other options can be 7(pose) or 71(exp+pose).
+ kernel (tuple, optional): _description_. Defaults to (3, 3).
+ c_in (int, optional): _description_. Defaults to 1.
+ hidden_size (int, optional): _description_. Defaults to 128.
+ norm_type (str, optional): _description_. Defaults to 'bn'.
+ reduction (str, optional): _description_. Defaults to 'sum'.
+ uncond_disc (bool, optional): _description_. Defaults to False.
+ """
+ super(Discriminator, self).__init__()
+ self.time_lengths = time_lengths
+ self.x_dim, self.y_dim = x_dim, y_dim
+ self.disc_type = disc_type
+ self.reduction = reduction
+ self.uncond_disc = uncond_disc
+
+ if uncond_disc:
+ self.x_dim = 0
+ cond_dim = 0
+
+ else:
+ cond_dim = 64
+ self.mel_encoder = nn.Sequential(*[
+ nn.Conv1d(self.x_dim, 64, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(64),
+ nn.GELU(),
+ nn.Conv1d(64, cond_dim, 3, 1, 1, bias=False)
+ ])
+
+ self.disc = MultiWindowDiscriminator(
+ time_lengths=self.time_lengths,
+ in_dim=self.y_dim,
+ cond_dim=cond_dim,
+ kernel_size=kernel_size,
+ hidden_size=hidden_size, norm_type=norm_type,
+ reduction=reduction,
+ disc_type=disc_type
+ )
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2))
+
+ @property
+ def device(self):
+ return self.disc.parameters().__next__().device
+
+ def forward(self,x, batch, start_frames_wins=None):
+ """
+
+ :param x: [B, T, C]
+ :param cond: [B, T, cond_size]
+ :return:
+ """
+ x = x.to(self.device)
+ if not self.uncond_disc:
+ mel = self.downsampler(batch['mel'].to(self.device))
+ mel_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2)
+ else:
+ mel_feat = None
+ x_len = x.sum(-1).ne(0).int().sum([1])
+ disc_confidence, start_frames_wins, h = self.disc(x, x_len, mel_feat, start_frames_wins=start_frames_wins)
+ return disc_confidence
+
diff --git a/modules/audio2motion/transformer_base.py b/modules/audio2motion/transformer_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..39bbe0073907742b2921b28afed2b241b7caeb60
--- /dev/null
+++ b/modules/audio2motion/transformer_base.py
@@ -0,0 +1,988 @@
+import math
+import torch
+from torch import nn
+from torch.nn import Parameter
+import torch.onnx.operators
+import torch.nn.functional as F
+from collections import defaultdict
+
+
+def make_positions(tensor, padding_idx):
+ """Replace non-padding symbols with their position numbers.
+
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
+ """
+ # The series of casts and type-conversions here are carefully
+ # balanced to both work with ONNX export and XLA. In particular XLA
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+ # how to handle the dtype kwarg in cumsum.
+ mask = tensor.ne(padding_idx).int()
+ return (
+ torch.cumsum(mask, dim=1).type_as(mask) * mask
+ ).long() + padding_idx
+
+
+def softmax(x, dim):
+ return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
+
+def _get_full_incremental_state_key(module_instance, key):
+ module_name = module_instance.__class__.__name__
+
+ # assign a unique ID to each module instance, so that incremental state is
+ # not shared across module instances
+ if not hasattr(module_instance, '_instance_id'):
+ INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
+ module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
+
+ return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
+
+
+
+def get_incremental_state(module, incremental_state, key):
+ """Helper for getting incremental state for an nn.Module."""
+ full_key = _get_full_incremental_state_key(module, key)
+ if incremental_state is None or full_key not in incremental_state:
+ return None
+ return incremental_state[full_key]
+
+
+def set_incremental_state(module, incremental_state, key, value):
+ """Helper for setting incremental state for an nn.Module."""
+ if incremental_state is not None:
+ full_key = _get_full_incremental_state_key(module, key)
+ incremental_state[full_key] = value
+
+
+
+class Reshape(nn.Module):
+ def __init__(self, *args):
+ super(Reshape, self).__init__()
+ self.shape = args
+
+ def forward(self, x):
+ return x.view(self.shape)
+
+
+class Permute(nn.Module):
+ def __init__(self, *args):
+ super(Permute, self).__init__()
+ self.args = args
+
+ def forward(self, x):
+ return x.permute(self.args)
+
+
+class LinearNorm(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+ super(LinearNorm, self).__init__()
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.linear_layer.weight,
+ gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, x):
+ return self.linear_layer(x)
+
+
+class ConvNorm(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
+ super(ConvNorm, self).__init__()
+ if padding is None:
+ assert (kernel_size % 2 == 1)
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
+ kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation,
+ bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+ return conv_signal
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+ if padding_idx is not None:
+ nn.init.constant_(m.weight[padding_idx], 0)
+ return m
+
+
+class GroupNorm1DTBC(nn.GroupNorm):
+ def forward(self, input):
+ return super(GroupNorm1DTBC, self).forward(input.permute(1, 2, 0)).permute(2, 0, 1)
+
+
+def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
+ if not export and torch.cuda.is_available():
+ try:
+ from apex.normalization import FusedLayerNorm
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+ except ImportError:
+ pass
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+
+def Linear(in_features, out_features, bias=True):
+ m = nn.Linear(in_features, out_features, bias)
+ nn.init.xavier_uniform_(m.weight)
+ if bias:
+ nn.init.constant_(m.bias, 0.)
+ return m
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+
+ Padding symbols are ignored.
+ """
+
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ init_size,
+ embedding_dim,
+ padding_idx,
+ )
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+ @staticmethod
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+ """Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
+
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = input.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if self.weights is None or max_pos > self.weights.size(0):
+ # recompute/expand embeddings if needed
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ max_pos,
+ self.embedding_dim,
+ self.padding_idx,
+ )
+ self.weights = self.weights.to(self._float_tensor)
+
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+ def max_positions(self):
+ """Maximum number of supported positions."""
+ return int(1e5) # an arbitrary large number
+
+
+class ConvTBC(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, padding=0):
+ super(ConvTBC, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.padding = padding
+
+ self.weight = torch.nn.Parameter(torch.Tensor(
+ self.kernel_size, in_channels, out_channels))
+ self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
+
+ def forward(self, input):
+ return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
+ encoder_decoder_attention=False):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim ** -0.5
+
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+ 'value to be of the same size'
+
+ if self.qkv_same_dim:
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+ else:
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+ if bias:
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self.reset_parameters()
+
+ self.enable_torch_version = False
+ if hasattr(F, "multi_head_attention_forward"):
+ self.enable_torch_version = True
+ else:
+ self.enable_torch_version = False
+ self.last_attn_probs = None
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ nn.init.xavier_uniform_(self.in_proj_weight)
+ else:
+ nn.init.xavier_uniform_(self.k_proj_weight)
+ nn.init.xavier_uniform_(self.v_proj_weight)
+ nn.init.xavier_uniform_(self.q_proj_weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.in_proj_bias is not None:
+ nn.init.constant_(self.in_proj_bias, 0.)
+ nn.init.constant_(self.out_proj.bias, 0.)
+ if self.bias_k is not None:
+ nn.init.xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ nn.init.xavier_normal_(self.bias_v)
+
+ def forward(
+ self,
+ query, key, value,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ reset_attn_weight=None
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
+ if self.qkv_same_dim:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask)
+ else:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ torch.empty([0]),
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask, use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight,
+ k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight)
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ # self-attention
+ q, k, v = self.in_proj_qkv(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(key)
+
+ else:
+ q = self.in_proj_q(query)
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(value)
+ q *= self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+ prev_key_padding_mask = saved_state['prev_key_padding_mask']
+ if static_kv:
+ key_padding_mask = prev_key_padding_mask
+ else:
+ key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
+
+ saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state['prev_key_padding_mask'] = key_padding_mask
+
+ self._set_input_buffer(incremental_state, saved_state)
+
+ src_len = k.size(1)
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights + attn_mask
+
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+ if reset_attn_weight is not None:
+ if reset_attn_weight:
+ self.last_attn_probs = attn_probs.detach()
+ else:
+ assert self.last_attn_probs is not None
+ attn_probs = self.last_attn_probs
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+ else:
+ attn_weights = None
+
+ return attn, (attn_weights, attn_logits)
+
+ def in_proj_qkv(self, query):
+ return self._in_proj(query).chunk(3, dim=-1)
+
+ def in_proj_q(self, query):
+ if self.qkv_same_dim:
+ return self._in_proj(query, end=self.embed_dim)
+ else:
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[:self.embed_dim]
+ return F.linear(query, self.q_proj_weight, bias)
+
+ def in_proj_k(self, key):
+ if self.qkv_same_dim:
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+ else:
+ weight = self.k_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[self.embed_dim:2 * self.embed_dim]
+ return F.linear(key, weight, bias)
+
+ def in_proj_v(self, value):
+ if self.qkv_same_dim:
+ return self._in_proj(value, start=2 * self.embed_dim)
+ else:
+ weight = self.v_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[2 * self.embed_dim:]
+ return F.linear(value, weight, bias)
+
+ def _in_proj(self, input, start=0, end=None):
+ weight = self.in_proj_weight
+ bias = self.in_proj_bias
+ weight = weight[start:end, :]
+ if bias is not None:
+ bias = bias[start:end]
+ return F.linear(input, weight, bias)
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ buffer,
+ )
+
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+ return attn_weights
+
+ def clear_buffer(self, incremental_state=None):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ del saved_state['prev_key']
+ if 'prev_value' in saved_state:
+ del saved_state['prev_value']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class Swish(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, i):
+ result = i * torch.sigmoid(i)
+ ctx.save_for_backward(i)
+ return result
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ i = ctx.saved_variables[0]
+ sigmoid_i = torch.sigmoid(i)
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class CustomSwish(nn.Module):
+ def forward(self, input_tensor):
+ return Swish.apply(input_tensor)
+
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.dropout = dropout
+ self.act = act
+ if padding == 'SAME':
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
+ elif padding == 'LEFT':
+ self.ffn_1 = nn.Sequential(
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
+ )
+ self.ffn_2 = Linear(filter_size, hidden_size)
+ if self.act == 'swish':
+ self.swish_fn = CustomSwish()
+
+ def forward(self, x, incremental_state=None):
+ # x: T x B x C
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ prev_input = saved_state['prev_input']
+ x = torch.cat((prev_input, x), dim=0)
+ x = x[-self.kernel_size:]
+ saved_state['prev_input'] = x
+ self._set_input_buffer(incremental_state, saved_state)
+
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
+ x = x * self.kernel_size ** -0.5
+
+ if incremental_state is not None:
+ x = x[-1:]
+ if self.act == 'gelu':
+ x = F.gelu(x)
+ if self.act == 'relu':
+ x = F.relu(x)
+ if self.act == 'swish':
+ x = self.swish_fn(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ buffer,
+ )
+
+ def clear_buffer(self, incremental_state):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ del saved_state['prev_input']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class BatchNorm1dTBC(nn.Module):
+ def __init__(self, c):
+ super(BatchNorm1dTBC, self).__init__()
+ self.bn = nn.BatchNorm1d(c)
+
+ def forward(self, x):
+ """
+
+ :param x: [T, B, C]
+ :return: [T, B, C]
+ """
+ x = x.permute(1, 2, 0) # [B, C, T]
+ x = self.bn(x) # [B, C, T]
+ x = x.permute(2, 0, 1) # [T, B, C]
+ return x
+
+
+class EncSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
+ relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.num_heads = num_heads
+ if num_heads > 0:
+ if norm == 'ln':
+ self.layer_norm1 = LayerNorm(c)
+ elif norm == 'bn':
+ self.layer_norm1 = BatchNorm1dTBC(c)
+ elif norm == 'gn':
+ self.layer_norm1 = GroupNorm1DTBC(8, c)
+ self.self_attn = MultiheadAttention(
+ self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
+ if norm == 'ln':
+ self.layer_norm2 = LayerNorm(c)
+ elif norm == 'bn':
+ self.layer_norm2 = BatchNorm1dTBC(c)
+ elif norm == 'gn':
+ self.layer_norm2 = GroupNorm1DTBC(8, c)
+ self.ffn = TransformerFFNLayer(
+ c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
+
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ if self.num_heads > 0:
+ residual = x
+ x = self.layer_norm1(x)
+ x, _, = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=encoder_padding_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+
+ residual = x
+ x = self.layer_norm2(x)
+ x = self.ffn(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+ return x
+
+
+class DecSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+ kernel_size=9, act='gelu', norm='ln'):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ if norm == 'ln':
+ self.layer_norm1 = LayerNorm(c)
+ elif norm == 'gn':
+ self.layer_norm1 = GroupNorm1DTBC(8, c)
+ self.self_attn = MultiheadAttention(
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+ )
+ if norm == 'ln':
+ self.layer_norm2 = LayerNorm(c)
+ elif norm == 'gn':
+ self.layer_norm2 = GroupNorm1DTBC(8, c)
+ self.encoder_attn = MultiheadAttention(
+ c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
+ )
+ if norm == 'ln':
+ self.layer_norm3 = LayerNorm(c)
+ elif norm == 'gn':
+ self.layer_norm3 = GroupNorm1DTBC(8, c)
+ self.ffn = TransformerFFNLayer(
+ c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ self_attn_mask=None,
+ self_attn_padding_mask=None,
+ attn_out=None,
+ reset_attn_weight=None,
+ **kwargs,
+ ):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ self.layer_norm3.training = layer_norm_training
+ residual = x
+ x = self.layer_norm1(x)
+ x, _ = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ incremental_state=incremental_state,
+ attn_mask=self_attn_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+
+ attn_logits = None
+ if encoder_out is not None or attn_out is not None:
+ residual = x
+ x = self.layer_norm2(x)
+ if encoder_out is not None:
+ x, attn = self.encoder_attn(
+ query=x,
+ key=encoder_out,
+ value=encoder_out,
+ key_padding_mask=encoder_padding_mask,
+ incremental_state=incremental_state,
+ static_kv=True,
+ enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
+ 'enc_dec_attn_constraint_mask'),
+ reset_attn_weight=reset_attn_weight
+ )
+ attn_logits = attn[1]
+ elif attn_out is not None:
+ x = self.encoder_attn.in_proj_v(attn_out)
+ if encoder_out is not None or attn_out is not None:
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+
+ residual = x
+ x = self.layer_norm3(x)
+ x = self.ffn(x, incremental_state=incremental_state)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ return x, attn_logits
+
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+ self.encoder_attn.clear_buffer(incremental_state)
+ self.ffn.clear_buffer(incremental_state)
+
+ def set_buffer(self, name, tensor, incremental_state):
+ return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
+ super().__init__()
+ self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
+ self.norm = norm
+ if self.norm == 'bn':
+ self.norm = nn.BatchNorm1d(n_chans)
+ elif self.norm == 'in':
+ self.norm = nn.InstanceNorm1d(n_chans, affine=True)
+ elif self.norm == 'gn':
+ self.norm = nn.GroupNorm(n_chans // 16, n_chans)
+ elif self.norm == 'ln':
+ self.norm = LayerNorm(n_chans // 16, n_chans)
+ elif self.norm == 'wn':
+ self.conv = torch.nn.utils.weight_norm(self.conv.conv)
+ self.dropout = nn.Dropout(dropout)
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ """
+
+ :param x: [B, C, T]
+ :return: [B, C, T]
+ """
+ x = self.conv(x)
+ if not isinstance(self.norm, str):
+ if self.norm == 'none':
+ pass
+ elif self.norm == 'ln':
+ x = self.norm(x.transpose(1, 2)).transpose(1, 2)
+ else:
+ x = self.norm(x)
+ x = self.relu(x)
+ x = self.dropout(x)
+ return x
+
+
+class ConvStacks(nn.Module):
+ def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
+ dropout=0, strides=None, res=True):
+ super().__init__()
+ self.conv = torch.nn.ModuleList()
+ self.kernel_size = kernel_size
+ self.res = res
+ self.in_proj = Linear(idim, n_chans)
+ if strides is None:
+ strides = [1] * n_layers
+ else:
+ assert len(strides) == n_layers
+ for idx in range(n_layers):
+ self.conv.append(ConvBlock(
+ n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
+ self.out_proj = Linear(n_chans, odim)
+
+ def forward(self, x, return_hiddens=False):
+ """
+
+ :param x: [B, T, H]
+ :return: [B, T, H]
+ """
+ x = self.in_proj(x)
+ x = x.transpose(1, -1) # (B, idim, Tmax)
+ hiddens = []
+ for f in self.conv:
+ x_ = f(x)
+ x = x + x_ if self.res else x_ # (B, C, Tmax)
+ hiddens.append(x)
+ x = x.transpose(1, -1)
+ x = self.out_proj(x) # (B, Tmax, H)
+ if return_hiddens:
+ hiddens = torch.stack(hiddens, 1) # [B, L, C, T]
+ return x, hiddens
+ return x
+
+
+class ConvGlobalStacks(nn.Module):
+ def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn', dropout=0,
+ strides=[2, 2, 2, 2, 2]):
+ super().__init__()
+ self.conv = torch.nn.ModuleList()
+ self.pooling = torch.nn.ModuleList()
+ self.kernel_size = kernel_size
+ self.in_proj = Linear(idim, n_chans)
+ for idx in range(n_layers):
+ self.conv.append(ConvBlock(n_chans, n_chans, kernel_size, stride=strides[idx],
+ norm=norm, dropout=dropout))
+ self.pooling.append(nn.MaxPool1d(strides[idx]))
+ self.out_proj = Linear(n_chans, odim)
+
+ def forward(self, x):
+ """
+
+ :param x: [B, T, H]
+ :return: [B, T, H]
+ """
+ x = self.in_proj(x)
+ x = x.transpose(1, -1) # (B, idim, Tmax)
+ for f, p in zip(self.conv, self.pooling):
+ x = f(x) # (B, C, T)
+ x = x.transpose(1, -1)
+ x = self.out_proj(x.mean(1)) # (B, H)
+ return x
+
+
+class ConvDecoder(nn.Module):
+ def __init__(self, c, dropout, kernel_size=9, act='gelu'):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+
+ self.pre_convs = nn.ModuleList()
+ self.pre_lns = nn.ModuleList()
+ for i in range(2):
+ self.pre_convs.append(TransformerFFNLayer(
+ c, c * 2, padding='LEFT', kernel_size=kernel_size, dropout=dropout, act=act))
+ self.pre_lns.append(LayerNorm(c))
+
+ self.layer_norm_attn = LayerNorm(c)
+ self.encoder_attn = MultiheadAttention(c, 1, encoder_decoder_attention=True, bias=False)
+
+ self.post_convs = nn.ModuleList()
+ self.post_lns = nn.ModuleList()
+ for i in range(8):
+ self.post_convs.append(TransformerFFNLayer(
+ c, c * 2, padding='LEFT', kernel_size=kernel_size, dropout=dropout, act=act))
+ self.post_lns.append(LayerNorm(c))
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ **kwargs,
+ ):
+ attn_logits = None
+ for conv, ln in zip(self.pre_convs, self.pre_lns):
+ residual = x
+ x = ln(x)
+ x = conv(x) + residual
+ if encoder_out is not None:
+ residual = x
+ x = self.layer_norm_attn(x)
+ x, attn = self.encoder_attn(
+ query=x,
+ key=encoder_out,
+ value=encoder_out,
+ key_padding_mask=encoder_padding_mask,
+ incremental_state=incremental_state,
+ static_kv=True,
+ enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
+ 'enc_dec_attn_constraint_mask'),
+ )
+ attn_logits = attn[1]
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ for conv, ln in zip(self.post_convs, self.post_lns):
+ residual = x
+ x = ln(x)
+ x = conv(x) + residual
+ return x, attn_logits
+
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+ self.encoder_attn.clear_buffer(incremental_state)
+ self.ffn.clear_buffer(incremental_state)
+
+ def set_buffer(self, name, tensor, incremental_state):
+ return set_incremental_state(self, incremental_state, name, tensor)
diff --git a/modules/audio2motion/transformer_models.py b/modules/audio2motion/transformer_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..05cc5ea196fb0af0f7cce8b8a41a2dcd5562f631
--- /dev/null
+++ b/modules/audio2motion/transformer_models.py
@@ -0,0 +1,208 @@
+from numpy import isin
+import torch
+import torch.nn as nn
+from modules.audio2motion.transformer_base import *
+
+DEFAULT_MAX_SOURCE_POSITIONS = 2000
+DEFAULT_MAX_TARGET_POSITIONS = 2000
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.op = EncSALayer(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size
+ if kernel_size is not None else 9,
+ padding='SAME',
+ norm=norm, act='gelu'
+ )
+
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+
+######################
+# fastspeech modules
+######################
+class LayerNorm(torch.nn.LayerNorm):
+ """Layer normalization module.
+ :param int nout: output dim size
+ :param int dim: dimension to be normalized
+ """
+
+ def __init__(self, nout, dim=-1, eps=1e-5):
+ """Construct an LayerNorm object."""
+ super(LayerNorm, self).__init__(nout, eps=eps)
+ self.dim = dim
+
+ def forward(self, x):
+ """Apply layer normalization.
+ :param torch.Tensor x: input tensor
+ :return: layer normalized tensor
+ :rtype torch.Tensor
+ """
+ if self.dim == -1:
+ return super(LayerNorm, self).forward(x)
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+class FFTBlocks(nn.Module):
+ def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None,
+ num_heads=2, use_pos_embed=True, use_last_norm=True, norm='ln',
+ use_pos_embed_alpha=True):
+ super().__init__()
+ self.num_layers = num_layers
+ embed_dim = self.hidden_size = hidden_size
+ self.dropout = dropout if dropout is not None else 0.1
+ self.use_pos_embed = use_pos_embed
+ self.use_last_norm = use_last_norm
+ if use_pos_embed:
+ self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
+ self.padding_idx = 0
+ self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
+ self.embed_positions = SinusoidalPositionalEmbedding(
+ embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+ )
+
+ self.layers = nn.ModuleList([])
+ self.layers.extend([
+ TransformerEncoderLayer(self.hidden_size, self.dropout,
+ kernel_size=ffn_kernel_size, num_heads=num_heads,
+ norm=norm)
+ for _ in range(self.num_layers)
+ ])
+ if self.use_last_norm:
+ if norm == 'ln':
+ self.layer_norm = nn.LayerNorm(embed_dim)
+ elif norm == 'bn':
+ self.layer_norm = BatchNorm1dTBC(embed_dim)
+ elif norm == 'gn':
+ self.layer_norm = GroupNorm1DTBC(8, embed_dim)
+ else:
+ self.layer_norm = None
+
+ def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
+ """
+ :param x: [B, T, C]
+ :param padding_mask: [B, T]
+ :return: [B, T, C] or [L, B, T, C]
+ """
+ padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
+ if self.use_pos_embed:
+ positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
+ x = x + positions
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1) * nonpadding_mask_TB
+ hiddens = []
+ for layer in self.layers:
+ x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
+ hiddens.append(x)
+ if self.use_last_norm:
+ x = self.layer_norm(x) * nonpadding_mask_TB
+ if return_hiddens:
+ x = torch.stack(hiddens, 0) # [L, T, B, C]
+ x = x.transpose(1, 2) # [L, B, T, C]
+ else:
+ x = x.transpose(0, 1) # [B, T, C]
+ return x
+
+class SequentialSA(nn.Module):
+ def __init__(self,layers):
+ super(SequentialSA,self).__init__()
+ self.layers = nn.ModuleList(layers)
+
+ def forward(self,x,x_mask):
+ """
+ x: [batch, T, H]
+ x_mask: [batch, T]
+ """
+ pad_mask = 1. - x_mask
+ for layer in self.layers:
+ if isinstance(layer, EncSALayer):
+ x = x.permute(1,0,2)
+ x = layer(x,pad_mask)
+ x = x.permute(1,0,2)
+ elif isinstance(layer, nn.Linear):
+ x = layer(x) * x_mask.unsqueeze(2)
+ elif isinstance(layer, nn.AvgPool1d):
+ x = x.permute(0,2,1)
+ x = layer(x)
+ x = x.permute(0,2,1)
+ elif isinstance(layer, nn.PReLU):
+ bs, t, hid = x.shape
+ x = x.reshape([bs*t,hid])
+ x = layer(x)
+ x = x.reshape([bs, t, hid])
+ else: # Relu
+ x = layer(x)
+
+ return x
+
+class TransformerStyleFusionModel(nn.Module):
+ def __init__(self, num_heads=4, dropout = 0.1, out_dim = 64):
+ super(TransformerStyleFusionModel, self).__init__()
+ self.audio_layer = SequentialSA([
+ nn.Linear(29, 48),
+ nn.ReLU(48),
+ nn.Linear(48, 128),
+ ])
+
+ self.energy_layer = SequentialSA([
+ nn.Linear(1, 16),
+ nn.ReLU(16),
+ nn.Linear(16, 64),
+ ])
+
+ self.backbone1 = FFTBlocks(hidden_size=192,num_layers=3)
+
+ self.sty_encoder = nn.Sequential(*[
+ nn.Linear(135, 64),
+ nn.ReLU(),
+ nn.Linear(64, 128)
+ ])
+
+ self.backbone2 = FFTBlocks(hidden_size=320,num_layers=3)
+
+ self.out_layer = SequentialSA([
+ nn.AvgPool1d(kernel_size=2,stride=2,padding=0), #[b,hid,t_audio]=>[b,hid,t_audio//2]
+ nn.Linear(320,out_dim),
+ nn.PReLU(out_dim),
+ nn.Linear(out_dim,out_dim),
+ ])
+
+ self.dropout = nn.Dropout(p = dropout)
+
+ def forward(self, audio, energy, style, x_mask, y_mask):
+ pad_mask = 1. - x_mask
+ audio_feat = self.audio_layer(audio, x_mask)
+ energy_feat = self.energy_layer(energy, x_mask)
+ feat = torch.cat((audio_feat, energy_feat), dim=-1) # [batch, T, H=48+16]
+ feat = self.backbone1(feat, pad_mask)
+ feat = self.dropout(feat)
+
+ sty_feat = self.sty_encoder(style) # [batch,135]=>[batch, H=64]
+ sty_feat = sty_feat.unsqueeze(1).repeat(1, feat.shape[1], 1) # [batch, T, H=64]
+
+ feat = torch.cat([feat, sty_feat], dim=-1) # [batch, T, H=64+64]
+ feat = self.backbone2(feat, pad_mask) # [batch, T, H=128]
+ out = self.out_layer(feat, y_mask) # [batch, T//2, H=out_dim]
+
+ return out
+
+
+if __name__ == '__main__':
+ model = TransformerStyleFusionModel()
+ audio = torch.rand(4,200,29) # [B,T,H]
+ energy = torch.rand(4,200,1) # [B,T,H]
+ style = torch.ones(4,135) # [B,T]
+ x_mask = torch.ones(4,200) # [B,T]
+ x_mask[3,10:] = 0
+ ret = model(audio,energy,style, x_mask)
+ print(" ")
\ No newline at end of file
diff --git a/modules/audio2motion/utils.py b/modules/audio2motion/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb56ec514bff822ba1a19a6474207ed82492410
--- /dev/null
+++ b/modules/audio2motion/utils.py
@@ -0,0 +1,29 @@
+import torch
+
+
+def squeeze(x, x_mask=None, n_sqz=2):
+ b, c, t = x.size()
+
+ t = (t // n_sqz) * n_sqz
+ x = x[:, :, :t]
+ x_sqz = x.view(b, c, t // n_sqz, n_sqz)
+ x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
+
+ if x_mask is not None:
+ x_mask = x_mask[:, :, n_sqz - 1::n_sqz]
+ else:
+ x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
+ return x_sqz * x_mask, x_mask
+
+
+def unsqueeze(x, x_mask=None, n_sqz=2):
+ b, c, t = x.size()
+
+ x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
+ x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
+
+ if x_mask is not None:
+ x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
+ else:
+ x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
+ return x_unsqz * x_mask, x_mask
diff --git a/modules/audio2motion/vae.py b/modules/audio2motion/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..9801ed631a6142297ce96d33c93ee508f32304b9
--- /dev/null
+++ b/modules/audio2motion/vae.py
@@ -0,0 +1,468 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+import torch.distributions as dist
+import numpy as np
+import copy
+from modules.audio2motion.flow_base import Glow, WN, ResidualCouplingBlock
+from modules.audio2motion.transformer_base import Embedding
+
+from utils.commons.pitch_utils import f0_to_coarse
+from utils.commons.hparams import hparams
+
+
+class LambdaLayer(nn.Module):
+ def __init__(self, lambd):
+ super(LambdaLayer, self).__init__()
+ self.lambd = lambd
+
+ def forward(self, x):
+ return self.lambd(x)
+
+
+def make_positions(tensor, padding_idx):
+ """Replace non-padding symbols with their position numbers.
+
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
+ """
+ # The series of casts and type-conversions here are carefully
+ # balanced to both work with ONNX export and XLA. In particular XLA
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+ # how to handle the dtype kwarg in cumsum.
+ mask = tensor.ne(padding_idx).int()
+ return (
+ torch.cumsum(mask, dim=1).type_as(mask) * mask
+ ).long() + padding_idx
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+
+ Padding symbols are ignored.
+ """
+
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ init_size,
+ embedding_dim,
+ padding_idx,
+ )
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+ @staticmethod
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+ """Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
+
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = input.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if self.weights is None or max_pos > self.weights.size(0):
+ # recompute/expand embeddings if needed
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ max_pos,
+ self.embedding_dim,
+ self.padding_idx,
+ )
+ self.weights = self.weights.to(self._float_tensor)
+
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+ def max_positions(self):
+ """Maximum number of supported positions."""
+ return int(1e4) # an arbitrary large number
+
+class FVAEEncoder(nn.Module):
+ def __init__(self, in_channels, hidden_channels, latent_channels, kernel_size,
+ n_layers, gin_channels=0, p_dropout=0, strides=[4]):
+ super().__init__()
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.pre_net = nn.Sequential(*[
+ nn.Conv1d(in_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ if i == 0 else
+ nn.Conv1d(hidden_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ for i, s in enumerate(strides)
+ ])
+ self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
+ self.out_proj = nn.Conv1d(hidden_channels, latent_channels * 2, 1)
+
+ self.latent_channels = latent_channels
+
+ def forward(self, x, x_mask, g):
+ x = self.pre_net(x)
+ x_mask = x_mask[:, :, ::np.prod(self.strides)][:, :, :x.shape[-1]]
+ x = x * x_mask
+ x = self.wn(x, x_mask, g) * x_mask
+ x = self.out_proj(x)
+ m, logs = torch.split(x, self.latent_channels, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs))
+ return z, m, logs, x_mask
+
+
+class FVAEDecoder(nn.Module):
+ def __init__(self, latent_channels, hidden_channels, out_channels, kernel_size,
+ n_layers, gin_channels=0, p_dropout=0,
+ strides=[4]):
+ super().__init__()
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.pre_net = nn.Sequential(*[
+ nn.ConvTranspose1d(latent_channels, hidden_channels, kernel_size=s, stride=s)
+ if i == 0 else
+ nn.ConvTranspose1d(hidden_channels, hidden_channels, kernel_size=s, stride=s)
+ for i, s in enumerate(strides)
+ ])
+ self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
+ self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)
+
+ def forward(self, x, x_mask, g):
+ x = self.pre_net(x)
+ x = x * x_mask
+ x = self.wn(x, x_mask, g) * x_mask
+ x = self.out_proj(x)
+ return x
+
+class FVAE(nn.Module):
+ def __init__(self,
+ in_out_channels=64, hidden_channels=256, latent_size=16,
+ kernel_size=3, enc_n_layers=5, dec_n_layers=5, gin_channels=80, strides=[4,],
+ use_prior_glow=True, glow_hidden=256, glow_kernel_size=3, glow_n_blocks=5,
+ sqz_prior=False, use_pos_emb=False):
+ super(FVAE, self).__init__()
+ self.in_out_channels = in_out_channels
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.latent_size = latent_size
+ self.use_prior_glow = use_prior_glow
+ self.sqz_prior = sqz_prior
+ self.g_pre_net = nn.Sequential(*[
+ nn.Conv1d(gin_channels, gin_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ for i, s in enumerate(strides)
+ ])
+ self.encoder = FVAEEncoder(in_out_channels, hidden_channels, latent_size, kernel_size,
+ enc_n_layers, gin_channels, strides=strides)
+ if use_prior_glow:
+ self.prior_flow = ResidualCouplingBlock(
+ latent_size, glow_hidden, glow_kernel_size, 1, glow_n_blocks, 4, gin_channels=gin_channels)
+ self.use_pos_embed = use_pos_emb
+ if sqz_prior:
+ self.query_proj = nn.Linear(latent_size, latent_size)
+ self.key_proj = nn.Linear(latent_size, latent_size)
+ self.value_proj = nn.Linear(latent_size, hidden_channels)
+ if self.in_out_channels in [7, 64]:
+ self.decoder = FVAEDecoder(hidden_channels, hidden_channels, in_out_channels, kernel_size,
+ dec_n_layers, gin_channels, strides=strides)
+ elif self.in_out_channels == 71:
+ self.exp_decoder = FVAEDecoder(hidden_channels, hidden_channels, 64, kernel_size,
+ dec_n_layers, gin_channels, strides=strides)
+ self.pose_decoder = FVAEDecoder(hidden_channels, hidden_channels, 7, kernel_size,
+ dec_n_layers, gin_channels, strides=strides)
+ if self.use_pos_embed:
+ self.embed_positions = SinusoidalPositionalEmbedding(self.latent_size, 0,init_size=2000+1,)
+ else:
+ self.decoder = FVAEDecoder(latent_size, hidden_channels, in_out_channels, kernel_size,
+ dec_n_layers, gin_channels, strides=strides)
+
+ self.prior_dist = dist.Normal(0, 1)
+
+ def forward(self, x=None, x_mask=None, g=None, infer=False, temperature=1. , **kwargs):
+ """
+
+ :param x: [B, T, C_in_out]
+ :param x_mask: [B, T]
+ :param g: [B, T, C_g]
+ :return:
+ """
+ x_mask = x_mask[:, None, :] # [B, 1, T]
+ g = g.transpose(1,2) # [B, C_g, T]
+ g_for_sqz = g
+
+ g_sqz = self.g_pre_net(g_for_sqz)
+
+ if not infer:
+ x = x.transpose(1,2) # [B, C, T]
+ z_q, m_q, logs_q, x_mask_sqz = self.encoder(x, x_mask, g_sqz)
+ if self.sqz_prior:
+ z = z_q
+ if self.use_pos_embed:
+ position = self.embed_positions(z.transpose(1,2).abs().sum(-1)).transpose(1,2)
+ z = z + position
+ q = self.query_proj(z.mean(dim=-1,keepdim=True).transpose(1,2)) # [B, 1, C=16]
+ k = self.key_proj(z.transpose(1,2)) # [B, T, C=16]
+ v = self.value_proj(z.transpose(1,2)) # [B, T, C=256]
+ attn = torch.bmm(q,k.transpose(1,2)) # [B, 1, T]
+ attn = F.softmax(attn, dim=-1)
+ out = torch.bmm(attn, v) # [B, 1, C=256]
+ style_encoding = out.repeat([1,z_q.shape[-1],1]).transpose(1,2) # [B, C=256, T]
+ if self.in_out_channels == 71:
+ x_recon = torch.cat([self.exp_decoder(style_encoding, x_mask, g), self.pose_decoder(style_encoding, x_mask, g)], dim=1)
+ else:
+ x_recon = self.decoder(style_encoding, x_mask, g)
+ else:
+ if self.in_out_channels == 71:
+ x_recon = torch.cat([self.exp_decoder(z_q, x_mask, g), self.pose_decoder(z_q, x_mask, g)], dim=1)
+ else:
+ x_recon = self.decoder(z_q, x_mask, g)
+ q_dist = dist.Normal(m_q, logs_q.exp())
+ if self.use_prior_glow:
+ logqx = q_dist.log_prob(z_q)
+ z_p = self.prior_flow(z_q, x_mask_sqz, g_sqz)
+ logpx = self.prior_dist.log_prob(z_p)
+ loss_kl = ((logqx - logpx) * x_mask_sqz).sum() / x_mask_sqz.sum() / logqx.shape[1]
+ else:
+ loss_kl = torch.distributions.kl_divergence(q_dist, self.prior_dist)
+ loss_kl = (loss_kl * x_mask_sqz).sum() / x_mask_sqz.sum() / z_q.shape[1]
+ z_p = z_q
+ return x_recon.transpose(1,2), loss_kl, z_p.transpose(1,2), m_q.transpose(1,2), logs_q.transpose(1,2)
+ else:
+ latent_shape = [g_sqz.shape[0], self.latent_size, g_sqz.shape[2]]
+ z_p = self.prior_dist.sample(latent_shape).to(g.device) * temperature # [B, latent_size, T_sqz]
+ if self.use_prior_glow:
+ z_p = self.prior_flow(z_p, 1, g_sqz, reverse=True)
+ if self.sqz_prior:
+ z = z_p
+ if self.use_pos_embed:
+ position = self.embed_positions(z.abs().sum(-1))
+ z += position
+ q = self.query_proj(z.mean(dim=-1,keepdim=True).transpose(1,2)) # [B, 1, C=16]
+ k = self.key_proj(z.transpose(1,2)) # [B, T, C=16]
+ v = self.value_proj(z.transpose(1,2)) # [B, T, C=256]
+ attn = torch.bmm(q,k.transpose(1,2)) # [B, 1, T]
+ attn = F.softmax(attn, dim=-1)
+ out = torch.bmm(attn, v) # [B, 1, C=256]
+ style_encoding = out.repeat([1,z_p.shape[-1],1]).transpose(1,2) # [B, C=256, T]
+ x_recon = self.decoder(style_encoding, 1, g)
+ if self.in_out_channels == 71:
+ x_recon = torch.cat([self.exp_decoder(style_encoding, 1, g), self.pose_decoder(style_encoding, 1, g)], dim=1)
+ else:
+ x_recon = self.decoder(style_encoding, 1, g)
+ else:
+ if self.in_out_channels == 71:
+ x_recon = torch.cat([self.exp_decoder(z_p, 1, g), self.pose_decoder(z_p, 1, g)], dim=1)
+ else:
+ x_recon = self.decoder(z_p, 1, g)
+ return x_recon.transpose(1,2), z_p.transpose(1,2)
+
+
+class VAEModel(nn.Module):
+ def __init__(self, in_out_dim=64, audio_in_dim=1024, sqz_prior=False, cond_drop=False, use_prior_flow=True):
+ super().__init__()
+ feat_dim = 64
+ self.blink_embed = nn.Embedding(2, feat_dim)
+ self.audio_in_dim = audio_in_dim
+ cond_dim = feat_dim
+ self.mel_encoder = nn.Sequential(*[
+ nn.Conv1d(audio_in_dim, 64, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(64),
+ nn.GELU(),
+ nn.Conv1d(64, feat_dim, 3, 1, 1, bias=False)
+ ])
+ self.cond_drop = cond_drop
+ if self.cond_drop:
+ self.dropout = nn.Dropout(0.5)
+
+ self.in_dim, self.out_dim = in_out_dim, in_out_dim
+ self.sqz_prior = sqz_prior
+ self.use_prior_flow = use_prior_flow
+ self.vae = FVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5,
+ enc_n_layers=8, dec_n_layers=4, gin_channels=cond_dim, strides=[4,],
+ use_prior_glow=self.use_prior_flow, glow_hidden=64, glow_kernel_size=3, glow_n_blocks=4,sqz_prior=sqz_prior)
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='linear').transpose(1,2))
+ # self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2))
+
+ def num_params(self, model, print_out=True, model_name="model"):
+ parameters = filter(lambda p: p.requires_grad, model.parameters())
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+ if print_out:
+ print(f'| {model_name} Trainable Parameters: %.3fM' % parameters)
+ return parameters
+
+ @property
+ def device(self):
+ return self.vae.parameters().__next__().device
+
+ def forward(self, batch, ret, train=True, return_latent=False, temperature=1.):
+ infer = not train
+ mask = batch['y_mask'].to(self.device)
+ mel = batch['audio'].to(self.device)
+ mel = self.downsampler(mel)
+ cond_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2)
+
+ if self.cond_drop:
+ cond_feat = self.dropout(cond_feat)
+
+ if not infer:
+ exp = batch['y'].to(self.device)
+ x = exp
+ x_recon, loss_kl, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=cond_feat, infer=False)
+ x_recon = x_recon * mask.unsqueeze(-1)
+ ret['pred'] = x_recon
+ ret['mask'] = mask
+ ret['loss_kl'] = loss_kl
+ if return_latent:
+ ret['m_q'] = m_q
+ ret['z_p'] = z_p
+ return x_recon, loss_kl, m_q, logs_q
+ else:
+ x_recon, z_p = self.vae(x=None, x_mask=mask, g=cond_feat, infer=True, temperature=temperature)
+ x_recon = x_recon * mask.unsqueeze(-1)
+ ret['pred'] = x_recon
+ ret['mask'] = mask
+
+ return x_recon
+
+
+class PitchContourVAEModel(nn.Module):
+ def __init__(self, hparams, in_out_dim=64, audio_in_dim=1024, sqz_prior=False, cond_drop=False, use_prior_flow=True):
+ super().__init__()
+ self.hparams = copy.deepcopy(hparams)
+ feat_dim = 128
+ self.audio_in_dim = audio_in_dim
+ self.blink_embed = nn.Embedding(2, feat_dim)
+
+ self.mel_encoder = nn.Sequential(*[
+ nn.Conv1d(audio_in_dim, feat_dim, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(feat_dim ),
+ nn.GELU(),
+ nn.Conv1d(feat_dim , feat_dim, 3, 1, 1, bias=False)
+ ])
+
+ self.pitch_embed = Embedding(300, feat_dim, None)
+ self.pitch_encoder = nn.Sequential(*[
+ nn.Conv1d(feat_dim, feat_dim , 3, 1, 1, bias=False),
+ nn.BatchNorm1d(feat_dim),
+ nn.GELU(),
+ nn.Conv1d(feat_dim, feat_dim, 3, 1, 1, bias=False)
+ ])
+
+ cond_dim = feat_dim + feat_dim + feat_dim
+
+ if hparams.get('use_mouth_amp_embed', False):
+ self.mouth_amp_embed = nn.Parameter(torch.randn(feat_dim))
+ cond_dim += feat_dim
+
+ if hparams.get('use_eye_amp_embed', False):
+ self.eye_amp_embed = nn.Parameter(torch.randn(feat_dim))
+ cond_dim += feat_dim
+
+ self.cond_proj = nn.Linear(cond_dim, feat_dim, bias=True)
+
+ self.cond_drop = cond_drop
+ if self.cond_drop:
+ self.dropout = nn.Dropout(0.5)
+
+ self.in_dim, self.out_dim = in_out_dim, in_out_dim
+ self.sqz_prior = sqz_prior
+ self.use_prior_flow = use_prior_flow
+ self.vae = FVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5,
+ enc_n_layers=8, dec_n_layers=4, gin_channels=feat_dim, strides=[4,],
+ use_prior_glow=self.use_prior_flow, glow_hidden=64, glow_kernel_size=3, glow_n_blocks=4,sqz_prior=sqz_prior)
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2))
+
+ def num_params(self, model, print_out=True, model_name="model"):
+ parameters = filter(lambda p: p.requires_grad, model.parameters())
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+ if print_out:
+ print(f'| {model_name} Trainable Parameters: %.3fM' % parameters)
+ return parameters
+
+ @property
+ def device(self):
+ return self.vae.parameters().__next__().device
+
+ def forward(self, batch, ret, train=True, return_latent=False, temperature=1.):
+ infer = not train
+ hparams = self.hparams
+ mask = batch['y_mask'].to(self.device)
+ mel = batch['audio'].to(self.device)
+ f0 = batch['f0'].to(self.device) # [b,t]
+ if 'blink' not in batch:
+ batch['blink'] = torch.zeros([f0.shape[0], f0.shape[1], 1], dtype=torch.long, device=f0.device)
+ blink = batch['blink'].to(self.device)
+ blink_feat = self.blink_embed(blink.squeeze(2))
+
+ blink_feat = self.downsampler(blink_feat)
+ mel = self.downsampler(mel)
+ f0 = self.downsampler(f0.unsqueeze(-1)).squeeze(-1)
+ f0_coarse = f0_to_coarse(f0)
+ pitch_emb = self.pitch_embed(f0_coarse)
+ cond_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2)
+ pitch_feat = self.pitch_encoder(pitch_emb.transpose(1,2)).transpose(1,2)
+
+ cond_feats = [cond_feat, pitch_feat, blink_feat]
+ if hparams.get('use_mouth_amp_embed', False):
+ mouth_amp = batch.get('mouth_amp', torch.ones([f0.shape[0], 1], device=f0.device) * 0.4)
+ mouth_amp_feat = mouth_amp.unsqueeze(1) * self.mouth_amp_embed.unsqueeze(0)
+ mouth_amp_feat = mouth_amp_feat.repeat([1,cond_feat.shape[1],1])
+ cond_feats.append(mouth_amp_feat)
+
+ if hparams.get('use_eye_amp_embed', False):
+ eye_amp = batch.get('eye_amp', torch.ones([f0.shape[0], 1], device=f0.device) * 0.4)
+ eye_amp_feat = eye_amp.unsqueeze(1) * self.eye_amp_embed.unsqueeze(0)
+ eye_amp_feat = eye_amp_feat.repeat([1,cond_feat.shape[1],1])
+ cond_feats.append(eye_amp_feat)
+
+ cond_feat = torch.cat(cond_feats, dim=-1)
+ cond_feat = self.cond_proj(cond_feat)
+
+ if self.cond_drop:
+ cond_feat = self.dropout(cond_feat)
+
+ if not infer:
+ exp = batch['y'].to(self.device)
+ x = exp
+ x_recon, loss_kl, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=cond_feat, infer=False)
+ x_recon = x_recon * mask.unsqueeze(-1)
+ ret['pred'] = x_recon
+ ret['mask'] = mask
+ ret['loss_kl'] = loss_kl
+ if return_latent:
+ ret['m_q'] = m_q
+ ret['z_p'] = z_p
+ return x_recon, loss_kl, m_q, logs_q
+ else:
+ x_recon, z_p = self.vae(x=None, x_mask=mask, g=cond_feat, infer=True, temperature=temperature)
+ x_recon = x_recon * mask.unsqueeze(-1)
+ ret['pred'] = x_recon
+ ret['mask'] = mask
+
+ return x_recon
+
+
+if __name__ == '__main__':
+ model = FVAE(in_out_channels=64, hidden_channels=128, latent_size=32,kernel_size=3, enc_n_layers=6, dec_n_layers=2,
+ gin_channels=80, strides=[4], use_prior_glow=False, glow_hidden=128, glow_kernel_size=3, glow_n_blocks=3)
+ x = torch.rand([8, 64, 1000])
+ x_mask = torch.ones([8,1,1000])
+ g = torch.rand([8, 80, 1000])
+ train_out = model(x,x_mask,g,infer=False)
+ x_recon, loss_kl, z_p, m_q, logs_q = train_out
+ print(" ")
+ infer_out = model(x,x_mask,g,infer=True)
+ x_recon, z_p = infer_out
+ print(" ")
diff --git a/modules/audio2motion/vqvae.py b/modules/audio2motion/vqvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..310ffc7bf4bf1c5a8c2901163439bba179a968fc
--- /dev/null
+++ b/modules/audio2motion/vqvae.py
@@ -0,0 +1,200 @@
+import scipy
+from scipy import linalg
+from torch.nn import functional as F
+import torch
+from torch import nn
+import numpy as np
+from modules.audio2motion.transformer_models import FFTBlocks
+import modules.audio2motion.utils as utils
+from modules.audio2motion.flow_base import Glow, WN, ResidualCouplingBlock
+import torch.distributions as dist
+from modules.audio2motion.cnn_models import LambdaLayer, LayerNorm
+
+from vector_quantize_pytorch import VectorQuantize
+
+
+class FVAEEncoder(nn.Module):
+ def __init__(self, in_channels, hidden_channels, latent_channels, kernel_size,
+ n_layers, gin_channels=0, p_dropout=0, strides=[4]):
+ super().__init__()
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.pre_net = nn.Sequential(*[
+ nn.Conv1d(in_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ if i == 0 else
+ nn.Conv1d(hidden_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ for i, s in enumerate(strides)
+ ])
+ self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
+ self.out_proj = nn.Conv1d(hidden_channels, latent_channels * 2, 1)
+ self.latent_channels = latent_channels
+
+ def forward(self, x, x_mask, g):
+ x = self.pre_net(x)
+ x_mask = x_mask[:, :, ::np.prod(self.strides)][:, :, :x.shape[-1]]
+ x = x * x_mask
+ x = self.wn(x, x_mask, g) * x_mask
+ x = self.out_proj(x)
+ m, logs = torch.split(x, self.latent_channels, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs))
+ return z, m, logs, x_mask
+
+
+class FVAEDecoder(nn.Module):
+ def __init__(self, latent_channels, hidden_channels, out_channels, kernel_size,
+ n_layers, gin_channels=0, p_dropout=0,
+ strides=[4]):
+ super().__init__()
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.pre_net = nn.Sequential(*[
+ nn.ConvTranspose1d(latent_channels, hidden_channels, kernel_size=s, stride=s)
+ if i == 0 else
+ nn.ConvTranspose1d(hidden_channels, hidden_channels, kernel_size=s, stride=s)
+ for i, s in enumerate(strides)
+ ])
+ self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
+ self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)
+
+ def forward(self, x, x_mask, g):
+ x = self.pre_net(x)
+ x = x * x_mask
+ x = self.wn(x, x_mask, g) * x_mask
+ x = self.out_proj(x)
+ return x
+
+
+class VQVAE(nn.Module):
+ def __init__(self,
+ in_out_channels=64, hidden_channels=256, latent_size=16,
+ kernel_size=3, enc_n_layers=5, dec_n_layers=5, gin_channels=80, strides=[4,],
+ sqz_prior=False):
+ super().__init__()
+ self.in_out_channels = in_out_channels
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.latent_size = latent_size
+ self.g_pre_net = nn.Sequential(*[
+ nn.Conv1d(gin_channels, gin_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ for i, s in enumerate(strides)
+ ])
+ self.encoder = FVAEEncoder(in_out_channels, hidden_channels, hidden_channels, kernel_size,
+ enc_n_layers, gin_channels, strides=strides)
+ # if use_prior_glow:
+ # self.prior_flow = ResidualCouplingBlock(
+ # latent_size, glow_hidden, glow_kernel_size, 1, glow_n_blocks, 4, gin_channels=gin_channels)
+ self.vq = VectorQuantize(dim=hidden_channels, codebook_size=256, codebook_dim=16)
+
+ self.decoder = FVAEDecoder(hidden_channels, hidden_channels, in_out_channels, kernel_size,
+ dec_n_layers, gin_channels, strides=strides)
+ self.prior_dist = dist.Normal(0, 1)
+ self.sqz_prior = sqz_prior
+
+ def forward(self, x=None, x_mask=None, g=None, infer=False, **kwargs):
+ """
+
+ :param x: [B, T, C_in_out]
+ :param x_mask: [B, T]
+ :param g: [B, T, C_g]
+ :return:
+ """
+ x_mask = x_mask[:, None, :] # [B, 1, T]
+ g = g.transpose(1,2) # [B, C_g, T]
+ g_for_sqz = g
+
+ g_sqz = self.g_pre_net(g_for_sqz)
+
+ if not infer:
+ x = x.transpose(1,2) # [B, C, T]
+ z_q, m_q, logs_q, x_mask_sqz = self.encoder(x, x_mask, g_sqz)
+ if self.sqz_prior:
+ z_q = F.interpolate(z_q, scale_factor=1/8)
+ z_p, idx, commit_loss = self.vq(z_q.transpose(1,2))
+ if self.sqz_prior:
+ z_p = F.interpolate(z_p.transpose(1,2),scale_factor=8).transpose(1,2)
+
+ x_recon = self.decoder(z_p.transpose(1,2), x_mask, g)
+ return x_recon.transpose(1,2), commit_loss, z_p.transpose(1,2), m_q.transpose(1,2), logs_q.transpose(1,2)
+ else:
+ bs, t = g_sqz.shape[0], g_sqz.shape[2]
+ if self.sqz_prior:
+ t = t // 8
+ latent_shape = [int(bs * t)]
+ latent_idx = torch.randint(0,256,latent_shape).to(self.vq.codebook.device)
+ # latent_idx = torch.ones_like(latent_idx, dtype=torch.long)
+ # z_p = torch.gather(self.vq.codebook, 0, latent_idx)# self.vq.codebook[latent_idx]
+ z_p = self.vq.codebook[latent_idx]
+ z_p = z_p.reshape([bs, t, -1])
+ z_p = self.vq.project_out(z_p)
+ if self.sqz_prior:
+ z_p = F.interpolate(z_p.transpose(1,2),scale_factor=8).transpose(1,2)
+
+ x_recon = self.decoder(z_p.transpose(1,2), 1, g)
+ return x_recon.transpose(1,2), z_p.transpose(1,2)
+
+
+class VQVAEModel(nn.Module):
+ def __init__(self, in_out_dim=71, sqz_prior=False, enc_no_cond=False):
+ super().__init__()
+ self.mel_encoder = nn.Sequential(*[
+ nn.Conv1d(80, 64, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(64),
+ nn.GELU(),
+ nn.Conv1d(64, 64, 3, 1, 1, bias=False)
+ ])
+ self.in_dim, self.out_dim = in_out_dim, in_out_dim
+ self.sqz_prior = sqz_prior
+ self.enc_no_cond = enc_no_cond
+ self.vae = VQVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5,
+ enc_n_layers=8, dec_n_layers=4, gin_channels=64, strides=[4,], sqz_prior=sqz_prior)
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2))
+
+ @property
+ def device(self):
+ return self.vae.parameters().__next__().device
+
+ def forward(self, batch, ret, log_dict=None, train=True):
+ infer = not train
+ mask = batch['y_mask'].to(self.device)
+ mel = batch['mel'].to(self.device)
+ mel = self.downsampler(mel)
+
+ mel_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2)
+ if not infer:
+ exp = batch['exp'].to(self.device)
+ pose = batch['pose'].to(self.device)
+ if self.in_dim == 71:
+ x = torch.cat([exp, pose], dim=-1) # [B, T, C=64 + 7]
+ elif self.in_dim == 64:
+ x = exp
+ elif self.in_dim == 7:
+ x = pose
+ if self.enc_no_cond:
+ x_recon, loss_commit, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=torch.zeros_like(mel_feat), infer=False)
+ else:
+ x_recon, loss_commit, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=mel_feat, infer=False)
+ loss_commit = loss_commit.reshape([])
+ ret['pred'] = x_recon
+ ret['mask'] = mask
+ ret['loss_commit'] = loss_commit
+ return x_recon, loss_commit, m_q, logs_q
+ else:
+ x_recon, z_p = self.vae(x=None, x_mask=mask, g=mel_feat, infer=True)
+ return x_recon
+
+ # def __get_feat(self, exp, pose):
+ # diff_exp = exp[:-1, :] - exp[1:, :]
+ # exp_std = (np.std(exp, axis = 0) - self.exp_std_mean) / self.exp_std_std
+ # diff_exp_std = (np.std(diff_exp, axis = 0) - self.exp_diff_std_mean) / self.exp_diff_std_std
+
+ # diff_pose = pose[:-1, :] - pose[1:, :]
+ # diff_pose_std = (np.std(diff_pose, axis = 0) - self.pose_diff_std_mean) / self.pose_diff_std_std
+
+ # return np.concatenate((exp_std, diff_exp_std, diff_pose_std))
+
+ def num_params(self, model, print_out=True, model_name="model"):
+ parameters = filter(lambda p: p.requires_grad, model.parameters())
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+ if print_out:
+ print(f'| {model_name} Trainable Parameters: %.3fM' % parameters)
+ return parameters
diff --git a/modules/commons/attention/attentions.py b/modules/commons/attention/attentions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b2b5bc03732ff17a0cb135e977fbe526dff3341
--- /dev/null
+++ b/modules/commons/attention/attentions.py
@@ -0,0 +1,427 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+import numpy as np
+from typing import Optional, Tuple
+
+
+class ScaledDotProductAttention(nn.Module):
+ """
+ Scaled Dot-Product Attention proposed in "Attention Is All You Need"
+ Compute the dot products of the query with all keys, divide each by sqrt(dim),
+ and apply a softmax function to obtain the weights on the values
+ Args: dim, mask
+ dim (int): dimention of attention
+ mask (torch.Tensor): tensor containing indices to be masked
+ Inputs: query, key, value, mask
+ - **query** (batch, q_len, d_model): tensor containing projection vector for decoder.
+ - **key** (batch, k_len, d_model): tensor containing projection vector for encoder.
+ - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
+ - **mask** (-): tensor containing indices to be masked
+ Returns: context, attn
+ - **context**: tensor containing the context vector from attention mechanism.
+ - **attn**: tensor containing the attention (alignment) from the encoder outputs.
+ """
+ def __init__(self, dim: int):
+ super(ScaledDotProductAttention, self).__init__()
+ self.sqrt_dim = np.sqrt(dim)
+
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
+ score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim
+
+ if mask is not None:
+ score.masked_fill_(mask.view(score.size()), -float('Inf'))
+
+ attn = F.softmax(score, -1)
+ context = torch.bmm(attn, value)
+ return context, attn
+
+
+class DotProductAttention(nn.Module):
+ """
+ Compute the dot products of the query with all values and apply a softmax function to obtain the weights on the values
+ """
+ def __init__(self, hidden_dim):
+ super(DotProductAttention, self).__init__()
+
+ def forward(self, query: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
+ batch_size, hidden_dim, input_size = query.size(0), query.size(2), value.size(1)
+
+ score = torch.bmm(query, value.transpose(1, 2))
+ attn = F.softmax(score.view(-1, input_size), dim=1).view(batch_size, -1, input_size)
+ context = torch.bmm(attn, value)
+
+ return context, attn
+
+
+class AdditiveAttention(nn.Module):
+ """
+ Applies a additive attention (bahdanau) mechanism on the output features from the decoder.
+ Additive attention proposed in "Neural Machine Translation by Jointly Learning to Align and Translate" paper.
+ Args:
+ hidden_dim (int): dimesion of hidden state vector
+ Inputs: query, value
+ - **query** (batch_size, q_len, hidden_dim): tensor containing the output features from the decoder.
+ - **value** (batch_size, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+ Returns: context, attn
+ - **context**: tensor containing the context vector from attention mechanism.
+ - **attn**: tensor containing the alignment from the encoder outputs.
+ Reference:
+ - **Neural Machine Translation by Jointly Learning to Align and Translate**: https://arxiv.org/abs/1409.0473
+ """
+ def __init__(self, hidden_dim: int) -> None:
+ super(AdditiveAttention, self).__init__()
+ self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+ self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+ self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
+ self.score_proj = nn.Linear(hidden_dim, 1)
+
+ def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
+ score = self.score_proj(torch.tanh(self.key_proj(key) + self.query_proj(query) + self.bias)).squeeze(-1)
+ attn = F.softmax(score, dim=-1)
+ context = torch.bmm(attn.unsqueeze(1), value)
+ return context, attn
+
+
+class LocationAwareAttention(nn.Module):
+ """
+ Applies a location-aware attention mechanism on the output features from the decoder.
+ Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
+ The location-aware attention mechanism is performing well in speech recognition tasks.
+ We refer to implementation of ClovaCall Attention style.
+ Args:
+ hidden_dim (int): dimesion of hidden state vector
+ smoothing (bool): flag indication whether to use smoothing or not.
+ Inputs: query, value, last_attn, smoothing
+ - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
+ - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+ - **last_attn** (batch_size * num_heads, v_len): tensor containing previous timestep`s attention (alignment)
+ Returns: output, attn
+ - **output** (batch, output_len, dimensions): tensor containing the feature from encoder outputs
+ - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
+ Reference:
+ - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
+ - **ClovaCall**: https://github.com/clovaai/ClovaCall/blob/master/las.pytorch/models/attention.py
+ """
+ def __init__(self, hidden_dim: int, smoothing: bool = True) -> None:
+ super(LocationAwareAttention, self).__init__()
+ self.hidden_dim = hidden_dim
+ self.conv1d = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=3, padding=1)
+ self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+ self.value_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+ self.score_proj = nn.Linear(hidden_dim, 1, bias=True)
+ self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
+ self.smoothing = smoothing
+
+ def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
+ batch_size, hidden_dim, seq_len = query.size(0), query.size(2), value.size(1)
+
+ # Initialize previous attention (alignment) to zeros
+ if last_attn is None:
+ last_attn = value.new_zeros(batch_size, seq_len)
+
+ conv_attn = torch.transpose(self.conv1d(last_attn.unsqueeze(1)), 1, 2)
+ score = self.score_proj(torch.tanh(
+ self.query_proj(query.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
+ + self.value_proj(value.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
+ + conv_attn
+ + self.bias
+ )).squeeze(dim=-1)
+
+ if self.smoothing:
+ score = torch.sigmoid(score)
+ attn = torch.div(score, score.sum(dim=-1).unsqueeze(dim=-1))
+ else:
+ attn = F.softmax(score, dim=-1)
+
+ context = torch.bmm(attn.unsqueeze(dim=1), value).squeeze(dim=1) # Bx1xT X BxTxD => Bx1xD => BxD
+
+ return context, attn
+
+
+class MultiHeadLocationAwareAttention(nn.Module):
+ """
+ Applies a multi-headed location-aware attention mechanism on the output features from the decoder.
+ Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
+ The location-aware attention mechanism is performing well in speech recognition tasks.
+ In the above paper applied a signle head, but we applied multi head concept.
+ Args:
+ hidden_dim (int): The number of expected features in the output
+ num_heads (int): The number of heads. (default: )
+ conv_out_channel (int): The number of out channel in convolution
+ Inputs: query, value, prev_attn
+ - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
+ - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+ - **prev_attn** (batch_size * num_heads, v_len): tensor containing previous timestep`s attention (alignment)
+ Returns: output, attn
+ - **output** (batch, output_len, dimensions): tensor containing the feature from encoder outputs
+ - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
+ Reference:
+ - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
+ - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
+ """
+ def __init__(self, hidden_dim: int, num_heads: int = 8, conv_out_channel: int = 10) -> None:
+ super(MultiHeadLocationAwareAttention, self).__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ self.dim = int(hidden_dim / num_heads)
+ self.conv1d = nn.Conv1d(num_heads, conv_out_channel, kernel_size=3, padding=1)
+ self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
+ self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
+ self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
+ self.score_proj = nn.Linear(self.dim, 1, bias=True)
+ self.bias = nn.Parameter(torch.rand(self.dim).uniform_(-0.1, 0.1))
+
+ def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
+ batch_size, seq_len = value.size(0), value.size(1)
+
+ if last_attn is None:
+ last_attn = value.new_zeros(batch_size, self.num_heads, seq_len)
+
+ loc_energy = torch.tanh(self.loc_proj(self.conv1d(last_attn).transpose(1, 2)))
+ loc_energy = loc_energy.unsqueeze(1).repeat(1, self.num_heads, 1, 1).view(-1, seq_len, self.dim)
+
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
+ value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
+ query = query.contiguous().view(-1, 1, self.dim)
+ value = value.contiguous().view(-1, seq_len, self.dim)
+
+ score = self.score_proj(torch.tanh(value + query + loc_energy + self.bias)).squeeze(2)
+ attn = F.softmax(score, dim=1)
+
+ value = value.view(batch_size, seq_len, self.num_heads, self.dim).permute(0, 2, 1, 3)
+ value = value.contiguous().view(-1, seq_len, self.dim)
+
+ context = torch.bmm(attn.unsqueeze(1), value).view(batch_size, -1, self.num_heads * self.dim)
+ attn = attn.view(batch_size, self.num_heads, -1)
+
+ return context, attn
+
+
+class MultiHeadAttention(nn.Module):
+ """
+ Multi-Head Attention proposed in "Attention Is All You Need"
+ Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
+ project the queries, keys and values h times with different, learned linear projections to d_head dimensions.
+ These are concatenated and once again projected, resulting in the final values.
+ Multi-head attention allows the model to jointly attend to information from different representation
+ subspaces at different positions.
+ MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
+ where head_i = Attention(Q · W_q, K · W_k, V · W_v)
+ Args:
+ d_model (int): The dimension of keys / values / quries (default: 512)
+ num_heads (int): The number of attention heads. (default: 8)
+ Inputs: query, key, value, mask
+ - **query** (batch, q_len, d_model): In transformer, three different ways:
+ Case 1: come from previoys decoder layer
+ Case 2: come from the input embedding
+ Case 3: come from the output embedding (masked)
+ - **key** (batch, k_len, d_model): In transformer, three different ways:
+ Case 1: come from the output of the encoder
+ Case 2: come from the input embeddings
+ Case 3: come from the output embedding (masked)
+ - **value** (batch, v_len, d_model): In transformer, three different ways:
+ Case 1: come from the output of the encoder
+ Case 2: come from the input embeddings
+ Case 3: come from the output embedding (masked)
+ - **mask** (-): tensor containing indices to be masked
+ Returns: output, attn
+ - **output** (batch, output_len, dimensions): tensor containing the attended output features.
+ - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
+ """
+ def __init__(self, d_model: int = 512, num_heads: int = 8):
+ super(MultiHeadAttention, self).__init__()
+
+ assert d_model % num_heads == 0, "d_model % num_heads should be zero."
+
+ self.d_head = int(d_model / num_heads)
+ self.num_heads = num_heads
+ self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
+ self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
+ self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
+ self.value_proj = nn.Linear(d_model, self.d_head * num_heads)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ mask: Optional[Tensor] = None
+ ) -> Tuple[Tensor, Tensor]:
+ batch_size = value.size(0)
+
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) # BxQ_LENxNxD
+ key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head) # BxK_LENxNxD
+ value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head) # BxV_LENxNxD
+
+ query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxQ_LENxD
+ key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxK_LENxD
+ value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxV_LENxD
+
+ if mask is not None:
+ mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) # BxNxQ_LENxK_LEN
+
+ context, attn = self.scaled_dot_attn(query, key, value, mask)
+
+ context = context.view(self.num_heads, batch_size, -1, self.d_head)
+ context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head) # BxTxND
+
+ return context, attn
+
+
+class RelativeMultiHeadAttention(nn.Module):
+ """
+ Multi-head attention with relative positional encoding.
+ This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+ Args:
+ d_model (int): The dimension of model
+ num_heads (int): The number of attention heads.
+ dropout_p (float): probability of dropout
+ Inputs: query, key, value, pos_embedding, mask
+ - **query** (batch, time, dim): Tensor containing query vector
+ - **key** (batch, time, dim): Tensor containing key vector
+ - **value** (batch, time, dim): Tensor containing value vector
+ - **pos_embedding** (batch, time, dim): Positional embedding tensor
+ - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
+ Returns:
+ - **outputs**: Tensor produces by relative multi head attention module.
+ """
+ def __init__(
+ self,
+ d_model: int = 512,
+ num_heads: int = 16,
+ dropout_p: float = 0.1,
+ ):
+ super(RelativeMultiHeadAttention, self).__init__()
+ assert d_model % num_heads == 0, "d_model % num_heads should be zero."
+ self.d_model = d_model
+ self.d_head = int(d_model / num_heads)
+ self.num_heads = num_heads
+ self.sqrt_dim = math.sqrt(d_model)
+
+ self.query_proj = nn.Linear(d_model, d_model)
+ self.key_proj = nn.Linear(d_model, d_model)
+ self.value_proj = nn.Linear(d_model, d_model)
+ self.pos_proj = nn.Linear(d_model, d_model, bias=False)
+
+ self.dropout = nn.Dropout(p=dropout_p)
+ self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+ self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+ torch.nn.init.xavier_uniform_(self.u_bias)
+ torch.nn.init.xavier_uniform_(self.v_bias)
+
+ self.out_proj = nn.Linear(d_model, d_model)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ pos_embedding: Tensor,
+ mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ batch_size = value.size(0)
+
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
+ key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
+ value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
+ pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
+
+ content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
+ pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
+ pos_score = self._compute_relative_positional_encoding(pos_score)
+
+ score = (content_score + pos_score) / self.sqrt_dim
+
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+ score.masked_fill_(mask, -1e9)
+
+ attn = F.softmax(score, -1)
+ attn = self.dropout(attn)
+
+ context = torch.matmul(attn, value).transpose(1, 2)
+ context = context.contiguous().view(batch_size, -1, self.d_model)
+
+ return self.out_proj(context)
+
+ def _compute_relative_positional_encoding(self, pos_score: Tensor) -> Tensor:
+ batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
+ zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
+ padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
+
+ padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
+ pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
+
+ return pos_score
+
+
+class CustomizingAttention(nn.Module):
+ r"""
+ Customizing Attention
+ Applies a multi-head + location-aware attention mechanism on the output features from the decoder.
+ Multi-head attention proposed in "Attention Is All You Need" paper.
+ Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
+ I combined these two attention mechanisms as custom.
+ Args:
+ hidden_dim (int): The number of expected features in the output
+ num_heads (int): The number of heads. (default: )
+ conv_out_channel (int): The dimension of convolution
+ Inputs: query, value, last_attn
+ - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
+ - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+ - **last_attn** (batch_size * num_heads, v_len): tensor containing previous timestep`s alignment
+ Returns: output, attn
+ - **output** (batch, output_len, dimensions): tensor containing the attended output features from the decoder.
+ - **attn** (batch * num_heads, v_len): tensor containing the alignment from the encoder outputs.
+ Reference:
+ - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
+ - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
+ """
+
+ def __init__(self, hidden_dim: int, num_heads: int = 4, conv_out_channel: int = 10) -> None:
+ super(CustomizingAttention, self).__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ self.dim = int(hidden_dim / num_heads)
+ self.scaled_dot_attn = ScaledDotProductAttention(self.dim)
+ self.conv1d = nn.Conv1d(1, conv_out_channel, kernel_size=3, padding=1)
+ self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=True)
+ self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
+ self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
+ self.bias = nn.Parameter(torch.rand(self.dim * num_heads).uniform_(-0.1, 0.1))
+
+ def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
+ batch_size, q_len, v_len = value.size(0), query.size(1), value.size(1)
+
+ if last_attn is None:
+ last_attn = value.new_zeros(batch_size * self.num_heads, v_len)
+
+ loc_energy = self.get_loc_energy(last_attn, batch_size, v_len) # get location energy
+
+ query = self.query_proj(query).view(batch_size, q_len, self.num_heads * self.dim)
+ value = self.value_proj(value).view(batch_size, v_len, self.num_heads * self.dim) + loc_energy + self.bias
+
+ query = query.view(batch_size, q_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
+ value = value.view(batch_size, v_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
+ query = query.contiguous().view(-1, q_len, self.dim)
+ value = value.contiguous().view(-1, v_len, self.dim)
+
+ context, attn = self.scaled_dot_attn(query, value)
+ attn = attn.squeeze()
+
+ context = context.view(self.num_heads, batch_size, q_len, self.dim).permute(1, 2, 0, 3)
+ context = context.contiguous().view(batch_size, q_len, -1)
+
+ return context, attn
+
+ def get_loc_energy(self, last_attn: Tensor, batch_size: int, v_len: int) -> Tensor:
+ conv_feat = self.conv1d(last_attn.unsqueeze(1))
+ conv_feat = conv_feat.view(batch_size, self.num_heads, -1, v_len).permute(0, 1, 3, 2)
+
+ loc_energy = self.loc_proj(conv_feat).view(batch_size, self.num_heads, v_len, self.dim)
+ loc_energy = loc_energy.permute(0, 2, 1, 3).reshape(batch_size, v_len, self.num_heads * self.dim)
+
+ return loc_energy
\ No newline at end of file
diff --git a/modules/commons/attention/simple_attention.py b/modules/commons/attention/simple_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8c451ce9324491a5c9fa8546b0fe98dc146c6c1
--- /dev/null
+++ b/modules/commons/attention/simple_attention.py
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def split_heads(x, num_heads):
+ """ Split heads
+ :param x: A tensor with shape [batch, length, channels]
+ :param num_heads: An integer
+ :returns: A tensor with shape [batch, heads, length, channels / heads]
+ """
+ assert x.shape[-1] % num_heads == 0, str(x.shape)
+ return x.reshape(x.shape[:-1] + (num_heads, x.shape[-1] // num_heads)).permute(0, 2, 1, 3)
+
+
+def combine_heads(x):
+ """ Combine heads
+ :param x: A tensor with shape [batch, heads, length, channels]
+ :returns: A tensor with shape [batch, length, heads * channels]
+ """
+ x = x.permute([0, 2, 1, 3])
+ return x.reshape(x.shape[:-2] + (x.shape[-1] * x.shape[-2],))
+
+
+class SimpleAttention(nn.Module):
+ def __init__(self, query_size=192, key_size=192, value_size=192, num_heads=1):
+ super(SimpleAttention, self).__init__()
+ self.q_transform = nn.Linear(query_size, query_size, bias=False)
+ self.k_transform = nn.Linear(key_size, query_size, bias=False)
+ self.v_transform = nn.Linear(value_size, query_size, bias=False)
+ self.output_transform = nn.Linear(query_size, query_size, bias=False)
+ self.query_size = query_size
+ self.key_size = key_size
+ self.value_size = value_size
+ self.num_heads = num_heads
+
+ def forward(self, query, key, value, attn_mask=None, bias=None):
+ q = self.q_transform(query)
+ k = self.k_transform(key)
+ v = self.v_transform(value)
+
+ logits = torch.bmm(q, k.transpose(1, 2)) # [batch, length_q, length_k]
+ if bias is not None:
+ logits += bias
+ if attn_mask is not None:
+ logits = logits + attn_mask * -1e9
+ weights = F.softmax(logits, dim=-1)
+ out = torch.bmm(weights, v)
+ out = self.output_transform(out)
+ return out, weights
diff --git a/modules/commons/conformer/conformer.py b/modules/commons/conformer/conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9b5d719a7b67ef745f178cf44e8c452191a3a2a
--- /dev/null
+++ b/modules/commons/conformer/conformer.py
@@ -0,0 +1,97 @@
+import torch
+from torch import nn
+from .espnet_positional_embedding import RelPositionalEncoding
+from .espnet_transformer_attn import RelPositionMultiHeadedAttention, MultiHeadedAttention
+from .layers import Swish, ConvolutionModule, EncoderLayer, MultiLayeredConv1d
+from ..layers import Embedding
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+class ConformerLayers(nn.Module):
+ def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4, use_last_norm=True):
+ super().__init__()
+ self.use_last_norm = use_last_norm
+ self.layers = nn.ModuleList()
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
+ self.encoder_layers = nn.ModuleList([EncoderLayer(
+ hidden_size,
+ MultiHeadedAttention(num_heads, hidden_size, 0.0),
+ positionwise_layer(*positionwise_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ ConvolutionModule(hidden_size, kernel_size, Swish()),
+ dropout,
+ ) for _ in range(num_layers)])
+ if self.use_last_norm:
+ self.layer_norm = nn.LayerNorm(hidden_size)
+ else:
+ self.layer_norm = nn.Linear(hidden_size, hidden_size)
+
+ def forward(self, x, x_mask):
+ """
+
+ :param x: [B, T, H]
+ :param padding_mask: [B, T]
+ :return: [B, T, H]
+ """
+ for l in self.encoder_layers:
+ x, mask = l(x, x_mask)
+ x = self.layer_norm(x) * x_mask
+ return x
+
+
+class ConformerEncoder(ConformerLayers):
+ def __init__(self, hidden_size, dict_size=0, in_size=0, strides=[2,2], num_layers=None):
+ conformer_enc_kernel_size = 9
+ super().__init__(hidden_size, num_layers, conformer_enc_kernel_size)
+ self.dict_size = dict_size
+ if dict_size != 0:
+ self.embed = Embedding(dict_size, hidden_size, padding_idx=0)
+ else:
+ self.seq_proj_in = torch.nn.Linear(in_size, hidden_size)
+ self.seq_proj_out = torch.nn.Linear(hidden_size, in_size)
+ self.mel_in = torch.nn.Linear(160, hidden_size)
+ self.mel_pre_net = torch.nn.Sequential(*[
+ torch.nn.Conv1d(hidden_size, hidden_size, kernel_size=s * 2, stride=s, padding=s // 2)
+ for i, s in enumerate(strides)
+ ])
+
+ def forward(self, seq_out, mels_timbre, other_embeds=0):
+ """
+
+ :param src_tokens: [B, T]
+ :return: [B x T x C]
+ """
+ x_lengths = (seq_out > 0).long().sum(-1)
+ x = seq_out
+ if self.dict_size != 0:
+ x = self.embed(x) + other_embeds # [B, T, H]
+ else:
+ x = self.seq_proj_in(x) + other_embeds # [B, T, H]
+ mels_timbre = self.mel_in(mels_timbre).transpose(1, 2)
+ mels_timbre = self.mel_pre_net(mels_timbre).transpose(1, 2)
+
+ T_out = x.size(1)
+ if self.dict_size != 0:
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths + mels_timbre.size(1), x.size(1) + mels_timbre.size(1)), 2).to(x.dtype)
+ else:
+ x_mask = torch.cat((torch.ones(x.size(0), mels_timbre.size(1), 1).to(x.device), (x.abs().sum(2) > 0).float()[:, :, None]), dim=1)
+ x = torch.cat((mels_timbre, x), 1)
+ x = super(ConformerEncoder, self).forward(x, x_mask)
+ if self.dict_size != 0:
+ x = x[:, -T_out:, :]
+ else:
+ x = self.seq_proj_out(x[:, -T_out:, :])
+ return x
+
+
+class ConformerDecoder(ConformerLayers):
+ def __init__(self, hidden_size, num_layers):
+ conformer_dec_kernel_size = 9
+ super().__init__(hidden_size, num_layers, conformer_dec_kernel_size)
diff --git a/modules/commons/conformer/espnet_positional_embedding.py b/modules/commons/conformer/espnet_positional_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b9b5549cc779d1ea67f052b1c99cad92365503
--- /dev/null
+++ b/modules/commons/conformer/espnet_positional_embedding.py
@@ -0,0 +1,113 @@
+import math
+import torch
+
+
+class PositionalEncoding(torch.nn.Module):
+ """Positional encoding.
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ reverse (bool): Whether to reverse the input position.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
+ """Construct an PositionalEncoding object."""
+ super(PositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.reverse = reverse
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x):
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ if self.pe.size(1) >= x.size(1):
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.d_model)
+ if self.reverse:
+ position = torch.arange(
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
+ ).unsqueeze(1)
+ else:
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: torch.Tensor):
+ """Add positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale + self.pe[:, : x.size(1)]
+ return self.dropout(x)
+
+
+class ScaledPositionalEncoding(PositionalEncoding):
+ """Scaled positional encoding module.
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Initialize class."""
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
+
+ def reset_parameters(self):
+ """Reset parameters."""
+ self.alpha.data = torch.tensor(1.0)
+
+ def forward(self, x):
+ """Add positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x + self.alpha * self.pe[:, : x.size(1)]
+ return self.dropout(x)
+
+
+class RelPositionalEncoding(PositionalEncoding):
+ """Relative positional encoding module.
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Initialize class."""
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+ def forward(self, x):
+ """Compute positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.pe[:, : x.size(1)]
+ return self.dropout(x), self.dropout(pos_emb)
\ No newline at end of file
diff --git a/modules/commons/conformer/espnet_transformer_attn.py b/modules/commons/conformer/espnet_transformer_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..48a52aacbaf07ef191c28baf12123036c2bc6b10
--- /dev/null
+++ b/modules/commons/conformer/espnet_transformer_attn.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer.
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadedAttention, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+ if not self.flash:
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+
+
+ def forward_qkv(self, query, key, value):
+ """Transform query, key and value.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ Returns:
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q, k, v
+
+ def forward_attention(self, value, scores, mask):
+ """Compute attention context vector.
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, query, key, value, mask):
+ """Compute scaled dot product attention.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+
+ B, Nh, Nt, E = q.shape
+ q = q / math.sqrt(E)
+ mask = mask * mask[:, None, :, 0]
+ mask = mask[:, None]
+ if self.flash:
+ attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False, attn_mask=mask)
+ else:
+ attn = self.slow_attn(q, k, v, is_causal=False, attn_mask=mask)
+ attn = attn.transpose(1, 2)
+ attn = attn.reshape(B, -1, self.h * self.d_k)
+ attn = self.linear_out(attn)
+ return attn
+
+ def slow_attn(self, Q, K, V, is_causal, attn_mask):
+ attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype == torch.bool else attn_mask
+ attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))) + attn_mask, dim=-1)
+ return attn_weight @ V
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding.
+ Paper: https://arxiv.org/abs/1901.02860
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate)
+ # linear transformation for positional ecoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x, zero_triu=False):
+ """Compute relative positinal encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, size).
+ zero_triu (bool): If true, return the lower triangular part of the matrix.
+ Returns:
+ torch.Tensor: Output tensor.
+ """
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)
+
+ if zero_triu:
+ ones = torch.ones((x.size(2), x.size(3)))
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+ return x
+
+ def forward(self, query, key, value, pos_emb, mask):
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ pos_emb (torch.Tensor): Positional embedding tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, time2)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k
+ ) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask)
diff --git a/modules/commons/conformer/layers.py b/modules/commons/conformer/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd7f501667e0b8aa816373d843adc816748e73a8
--- /dev/null
+++ b/modules/commons/conformer/layers.py
@@ -0,0 +1,260 @@
+from torch import nn
+import torch
+
+from modules.commons.layers import LayerNorm
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model.
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernerl size of conv layers.
+ """
+
+ def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
+ """Construct an ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+ # kernerl_size should be a odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ groups=channels,
+ bias=bias,
+ )
+ self.norm = nn.BatchNorm1d(channels)
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = activation
+
+ def forward(self, x):
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2)
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ x = self.activation(self.norm(x))
+
+ x = self.pointwise_conv2(x)
+
+ return x.transpose(1, 2)
+
+
+class MultiLayeredConv1d(torch.nn.Module):
+ """Multi-layered conv1d for Transformer block.
+ This is a module of multi-leyered conv1d designed
+ to replace positionwise feed-forward network
+ in Transforner block, which is introduced in
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
+ https://arxiv.org/pdf/1905.09263.pdf
+ """
+
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
+ """Initialize MultiLayeredConv1d module.
+ Args:
+ in_chans (int): Number of input channels.
+ hidden_chans (int): Number of hidden channels.
+ kernel_size (int): Kernel size of conv1d.
+ dropout_rate (float): Dropout rate.
+ """
+ super(MultiLayeredConv1d, self).__init__()
+ self.w_1 = torch.nn.Conv1d(
+ in_chans,
+ hidden_chans,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.w_2 = torch.nn.Conv1d(
+ hidden_chans,
+ in_chans,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.dropout = torch.nn.Dropout(dropout_rate)
+
+ def forward(self, x):
+ """Calculate forward propagation.
+ Args:
+ x (torch.Tensor): Batch of input tensors (B, T, in_chans).
+ Returns:
+ torch.Tensor: Batch of output tensors (B, T, hidden_chans).
+ """
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
+ return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
+
+
+class Swish(torch.nn.Module):
+ """Construct an Swish object."""
+
+ def forward(self, x):
+ """Return Swich activation function."""
+ return x * torch.sigmoid(x)
+
+
+class EncoderLayer(nn.Module):
+ """Encoder layer module.
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+ can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ conv_module (torch.nn.Module): Convolution module instance.
+ `ConvlutionModule` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+ """
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ feed_forward,
+ feed_forward_macaron,
+ conv_module,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+ """Construct an EncoderLayer object."""
+ super(EncoderLayer, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ self.norm_ff = LayerNorm(size) # for the FNN module
+ self.norm_mha = LayerNorm(size) # for the MHA module
+ if feed_forward_macaron is not None:
+ self.norm_ff_macaron = LayerNorm(size)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+ if self.conv_module is not None:
+ self.norm_conv = LayerNorm(size) # for the CNN module
+ self.norm_final = LayerNorm(size) # for the final output of the block
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+
+ def forward(self, x_input, mask, cache=None):
+ """Compute encoded features.
+ Args:
+ x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+ - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+ - w/o pos emb: Tensor (#batch, time, size).
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time).
+ """
+ if isinstance(x_input, tuple):
+ x, pos_emb = x_input[0], x_input[1]
+ else:
+ x, pos_emb = x_input, None
+
+ # whether to use macaron style
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff_macaron(x)
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
+ if not self.normalize_before:
+ x = self.norm_ff_macaron(x)
+
+ # multi-headed self-attention module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_mha(x)
+
+ if cache is None:
+ x_q = x
+ else:
+ assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+ x_q = x[:, -1:, :]
+ residual = residual[:, -1:, :]
+ mask = None if mask is None else mask[:, -1:, :]
+
+ if pos_emb is not None:
+ x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+ else:
+ x_att = self.self_attn(x_q, x, x, mask)
+
+ if self.concat_after:
+ x_concat = torch.cat((x, x_att), dim=-1)
+ x = residual + self.concat_linear(x_concat)
+ else:
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm_mha(x)
+
+ # convolution module
+ if self.conv_module is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_conv(x)
+ x = residual + self.dropout(self.conv_module(x))
+ if not self.normalize_before:
+ x = self.norm_conv(x)
+
+ # feed forward module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff(x)
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm_ff(x)
+
+ if self.conv_module is not None:
+ x = self.norm_final(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ if pos_emb is not None:
+ return (x, pos_emb), mask
+
+ return x, mask
diff --git a/modules/commons/conv.py b/modules/commons/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..a601f06042c2db37ace11ce72149101a9b8aefe4
--- /dev/null
+++ b/modules/commons/conv.py
@@ -0,0 +1,198 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modules.commons.layers import LayerNorm, Embedding
+
+
+class LambdaLayer(nn.Module):
+ def __init__(self, lambd):
+ super(LambdaLayer, self).__init__()
+ self.lambd = lambd
+
+ def forward(self, x):
+ return self.lambd(x)
+
+
+def init_weights_func(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv1d") != -1:
+ torch.nn.init.xavier_uniform_(m.weight)
+
+
+class ResidualBlock(nn.Module):
+ """Implements conv->PReLU->norm n-times"""
+
+ def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
+ c_multiple=2, ln_eps=1e-12, left_pad=False):
+ super(ResidualBlock, self).__init__()
+
+ if norm_type == 'bn':
+ norm_builder = lambda: nn.BatchNorm1d(channels)
+ elif norm_type == 'in':
+ norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
+ elif norm_type == 'gn':
+ norm_builder = lambda: nn.GroupNorm(8, channels)
+ elif norm_type == 'ln':
+ norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
+ else:
+ norm_builder = lambda: nn.Identity()
+
+ if left_pad:
+ self.blocks = [
+ nn.Sequential(
+ norm_builder(),
+ nn.ConstantPad1d(((dilation * (kernel_size - 1)) // 2 * 2, 0), 0),
+ nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation, padding=0),
+ LambdaLayer(lambda x: x * kernel_size ** -0.5),
+ nn.GELU(),
+ nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation, padding_mode='reflect'),
+ )
+ for i in range(n)
+ ]
+ else:
+ self.blocks = [
+ nn.Sequential(
+ norm_builder(),
+ nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
+ padding=(dilation * (kernel_size - 1)) // 2, padding_mode='reflect'),
+ LambdaLayer(lambda x: x * kernel_size ** -0.5),
+ nn.GELU(),
+ nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation, padding_mode='reflect'),
+ )
+ for i in range(n)
+ ]
+
+ self.blocks = nn.ModuleList(self.blocks)
+ self.dropout = dropout
+
+ def forward(self, x):
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
+ for b in self.blocks:
+ x_ = b(x)
+ if self.dropout > 0 and self.training:
+ x_ = F.dropout(x_, self.dropout, training=self.training)
+ x = x + x_
+ x = x * nonpadding
+ return x
+
+
+class ConvBlocks(nn.Module):
+ """Decodes the expanded phoneme encoding into spectrograms"""
+
+ def __init__(self, hidden_size, out_dims, dilations, kernel_size,
+ norm_type='ln', layers_in_block=2, c_multiple=2,
+ dropout=0.0, ln_eps=1e-5,
+ init_weights=True, is_BTC=True, num_layers=None, post_net_kernel=3,
+ left_pad=False, c_in=None):
+ super(ConvBlocks, self).__init__()
+ self.is_BTC = is_BTC
+ if num_layers is not None:
+ dilations = [1] * num_layers
+ self.res_blocks = nn.Sequential(
+ *[ResidualBlock(hidden_size, kernel_size, d,
+ n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
+ dropout=dropout, ln_eps=ln_eps, left_pad=left_pad)
+ for d in dilations],
+ )
+ if norm_type == 'bn':
+ norm = nn.BatchNorm1d(hidden_size)
+ elif norm_type == 'in':
+ norm = nn.InstanceNorm1d(hidden_size, affine=True)
+ elif norm_type == 'gn':
+ norm = nn.GroupNorm(8, hidden_size)
+ elif norm_type == 'ln':
+ norm = LayerNorm(hidden_size, dim=1, eps=ln_eps)
+ self.last_norm = norm
+ if left_pad:
+ self.post_net1 = nn.Sequential(
+ nn.ConstantPad1d((post_net_kernel // 2 * 2, 0), 0),
+ nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel, padding=0),
+ )
+ else:
+ self.post_net1 = nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel,
+ padding=post_net_kernel // 2, padding_mode='reflect')
+ self.c_in = c_in
+ if c_in is not None:
+ self.in_conv = nn.Conv1d(c_in, hidden_size, kernel_size=1, padding_mode='reflect')
+ if init_weights:
+ self.apply(init_weights_func)
+
+ def forward(self, x, nonpadding=None):
+ """
+
+ :param x: [B, T, H]
+ :return: [B, T, H]
+ """
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ if self.c_in is not None:
+ x = self.in_conv(x)
+ if nonpadding is None:
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
+ elif self.is_BTC:
+ nonpadding = nonpadding.transpose(1, 2)
+ x = self.res_blocks(x) * nonpadding
+ x = self.last_norm(x) * nonpadding
+ x = self.post_net1(x) * nonpadding
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ return x
+
+
+class TextConvEncoder(ConvBlocks):
+ def __init__(self, dict_size, hidden_size, out_dims, dilations, kernel_size,
+ norm_type='ln', layers_in_block=2, c_multiple=2,
+ dropout=0.0, ln_eps=1e-5, init_weights=True, num_layers=None, post_net_kernel=3):
+ super().__init__(hidden_size, out_dims, dilations, kernel_size,
+ norm_type, layers_in_block, c_multiple,
+ dropout, ln_eps, init_weights, num_layers=num_layers,
+ post_net_kernel=post_net_kernel)
+ self.dict_size = dict_size
+ if dict_size > 0:
+ self.embed_tokens = Embedding(dict_size, hidden_size, 0)
+ self.embed_scale = math.sqrt(hidden_size)
+
+ def forward(self, txt_tokens, other_embeds=0):
+ """
+
+ :param txt_tokens: [B, T]
+ :return: {
+ 'encoder_out': [B x T x C]
+ }
+ """
+ if self.dict_size > 0:
+ x = self.embed_scale * self.embed_tokens(txt_tokens)
+ else:
+ x = txt_tokens
+ x = x + other_embeds
+ return super().forward(x, nonpadding=(txt_tokens > 0).float()[..., None])
+
+
+class ConditionalConvBlocks(ConvBlocks):
+ def __init__(self, hidden_size, c_cond, c_out, dilations, kernel_size,
+ norm_type='ln', layers_in_block=2, c_multiple=2,
+ dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, num_layers=None):
+ super().__init__(hidden_size, c_out, dilations, kernel_size,
+ norm_type, layers_in_block, c_multiple,
+ dropout, ln_eps, init_weights, is_BTC=False, num_layers=num_layers)
+ self.g_prenet = nn.Conv1d(c_cond, hidden_size, 3, padding=1, padding_mode='reflect')
+ self.is_BTC_ = is_BTC
+ if init_weights:
+ self.g_prenet.apply(init_weights_func)
+
+ def forward(self, x, cond, nonpadding=None):
+ if self.is_BTC_:
+ x = x.transpose(1, 2)
+ cond = cond.transpose(1, 2)
+ if nonpadding is not None:
+ nonpadding = nonpadding.transpose(1, 2)
+ if nonpadding is None:
+ nonpadding = x.abs().sum(1)[:, None]
+ x = x + self.g_prenet(cond)
+ x = x * nonpadding
+ x = super(ConditionalConvBlocks, self).forward(x) # input needs to be BTC
+ if self.is_BTC_:
+ x = x.transpose(1, 2)
+ return x
diff --git a/modules/commons/gpt.py b/modules/commons/gpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..16e40349d0fae65107206033583d2cdc55289d09
--- /dev/null
+++ b/modules/commons/gpt.py
@@ -0,0 +1,474 @@
+import math
+import torch
+from typing import Optional, Tuple
+from torch import nn
+from utils.nn.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
+import torch.nn.functional as F
+
+# from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
+
+DEFAULT_MAX_SOURCE_POSITIONS = 20000
+DEFAULT_MAX_TARGET_POSITIONS = 20000
+
+
+class RotaryEmbeddings(nn.Module):
+ cos: torch.Tensor
+ sin: torch.Tensor
+ theta: torch.Tensor
+
+ def __init__(
+ self,
+ width: int,
+ *,
+ seq_len: int = 4000,
+ base: int = 10000,
+ device: Optional[torch.device] = None,
+ ):
+ """Rotary embeddings (Su et al., 2021) layer. The rotary embedding
+ will be precomputed for up to 'seq _len' positions. The embedding
+ will be recomputed when a longer sequence is found in the input.
+
+ :param width:
+ Rotary embedding dimensionality, must be even.
+ :param seq_len:
+ Number of positons to initially precompute.
+ :param base:
+ The base used for Θ_i, determines the cycle length of the
+ embeddings.
+ :param device: Device on which the module is to be initialized.
+ """
+ super().__init__()
+
+ if width % 2:
+ raise ValueError(f"Width of rotary embeddings must be even, was: {width}")
+
+ # Ignore allocations on the meta device as we don't persist our buffer,
+ # i.e., we don't expect the backing tensor to be replaced with pretrained weights.
+ if device is not None and device.type == "meta":
+ device = None
+ # Θ_i = 10000^(-2(i-1)/d)
+ theta = torch.pow(
+ base, -torch.arange(0, width, 2, dtype=torch.float, device=device) / width
+ )
+ self.register_buffer("theta", theta, persistent=False)
+
+ self._create_rotary_embed(width=width, length=seq_len)
+
+ def _create_rotary_embed(self, *, width: int, length: int):
+ # mΘ
+ position = torch.arange(length, device=self.theta.device).unsqueeze(1)
+ m_theta = position * self.theta.unsqueeze(0)
+
+ # We apply both sin and cos twice (see Eq 15, 34), but the ordering
+ # is changed for compatibility with most common implementations.
+ m_theta = torch.cat([m_theta, m_theta], dim=-1)
+
+ re_cos = m_theta.cos().view([length, width]).half()
+ re_sin = m_theta.sin().view([length, width]).half()
+
+ self.register_buffer("cos", re_cos, persistent=False)
+ self.register_buffer("sin", re_sin, persistent=False)
+
+ def _rotate(self, input: torch.Tensor):
+ """Rotate the input tensor by half of its innermost width.
+
+ input (Tensor): array to rotate.
+ RETURNS (Tensor): rotated array.
+
+ Shapes:
+ input - (..., width)
+ output - (..., width)
+ """
+ half_idx = input.shape[-1] // 2
+ input_1 = -input[..., half_idx:]
+ input_2 = input[..., :half_idx]
+ return torch.cat([input_1, input_2], dim=-1)
+
+ def forward(self, input: torch.Tensor, *, positions: Optional[torch.Tensor] = None):
+ """
+ Apply rotary embeddings to an array.
+
+ :param input: Array to apply the rotary embeddings to.
+ :param positions: positions of the inputs. If no positions are
+ provided, they are assumed to be [0, seq_len).
+ :return: Array with the rotary embeddings applied.
+
+ Shapes:
+ input - (batch_size, num_heads, seq_len, width_per_head)
+ positions - (batch_size, seq_len)
+ output - (batch_size, num_heads, seq_len, width_per_head)
+ """
+ batch_size, _, seq_len, width = input.shape
+
+ if positions is None:
+ # Fastpath: positions from [0..seq_len), avoid indexing.
+ if self.cos.size(-2) < seq_len:
+ self._create_rotary_embed(width=width, length=seq_len)
+ rot_cos = self.cos[:seq_len, :].view(1, 1, seq_len, width)
+ rot_sin = self.sin[:seq_len, :].view(1, 1, seq_len, width)
+ else:
+ max_len = int(positions.max()) + 1
+ if self.cos.size(-2) < max_len:
+ self._create_rotary_embed(width=width, length=max_len)
+
+ # Flatten positions to index cos/sin arrays, then unflatten.
+ #
+ # Example shapes:
+ #
+ # positions_flat - (batch_size * seq_len)
+ # self.cos - (max_len, width)
+ # rot_cos - (batch_size, seq_len, width)
+ positions_flat = positions.view(-1)
+ rot_cos = self.cos[positions_flat].view(batch_size, 1, seq_len, width)
+ rot_sin = self.sin[positions_flat].view(batch_size, 1, seq_len, width)
+
+ # Eq 34 with ordering changed for compatibility.
+ return rot_cos * input + rot_sin * self._rotate(input)
+
+
+class LayerNorm(nn.Module):
+ """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
+
+ def __init__(self, ndim, bias=False):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(ndim))
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+ def forward(self, input):
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
+
+
+class CausalSelfAttention(nn.Module):
+ def __init__(self, embed_dim, num_heads, dropout=0.):
+ super().__init__()
+ # Typically, bias = True in Linears and LayerNorms, like GPT-2. But we set bias = False: a bit better and faster (following https://github.com/karpathy/nanoGPT)
+ assert embed_dim % num_heads == 0
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim ** -0.5
+ # key, query, value projections for all heads, but in a batch
+ self.c_attn = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
+ # output projection
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+ # rotary embeddings
+ self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+ if not self.flash:
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+
+ def forward(
+ self,
+ query, key, value,
+ spk_pos_ids_flat=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ else:
+ saved_state = None
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ q, k, v = self.c_attn(query).split(self.embed_dim, dim=2)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ # Apply rot embedding and store incremental_state
+ q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
+ bsz, self.num_heads, -1, self.head_dim)
+ self._set_input_buffer(incremental_state, saved_state)
+ if incremental_state is not None:
+ key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
+ else:
+ key_pos = spk_pos_ids_flat
+ k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
+
+ src_len = k.size(1)
+
+ # Start Attention
+ if self.flash:
+ # efficient attention using Flash Attention CUDA kernels
+ attn = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=attn_mask, dropout_p=0,
+ is_causal=False)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+
+ # Flash Attn 2
+ # from flash_attn import flash_attn_func
+ # q, k, v = q.transpose(0, 1)[None, :], k.transpose(0, 1)[None, :], v.transpose(0, 1)[None, :]
+ # attn = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)[0].contiguous().view(tgt_len, bsz, embed_dim)
+
+ attn = self.out_proj(attn)
+ attn_logits = None
+ else:
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights + attn_mask
+
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ attn_weights_float = softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+ else:
+ attn_weights = None
+
+ return attn, (attn_weights, attn_logits)
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ buffer,
+ )
+
+ def clear_buffer(self, incremental_state=None):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ del saved_state['prev_key']
+ if 'prev_value' in saved_state:
+ del saved_state['prev_value']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.dropout = dropout
+ self.act = act
+ if padding == 'SAME':
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2, bias=False)
+ elif padding == 'LEFT':
+ self.ffn_1 = nn.Sequential(
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+ nn.Conv1d(hidden_size, filter_size, kernel_size, bias=False)
+ )
+ self.ffn_2 = nn.Linear(filter_size, hidden_size, bias=False)
+
+ def forward(self, x, incremental_state=None):
+ # x: T x B x C
+ if incremental_state is not None:
+ T_inp = x.shape[0]
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ prev_input = saved_state['prev_input']
+ x = torch.cat((prev_input, x), dim=0)
+ x = x[-self.kernel_size:]
+ saved_state['prev_input'] = x
+ self._set_input_buffer(incremental_state, saved_state)
+
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
+ x = x * self.kernel_size ** -0.5
+
+ if incremental_state is not None:
+ x = x[-T_inp:]
+ # if self.act == 'gelu':
+ # x = F.gelu(x)
+ # if self.act == 'relu':
+ # x = F.relu(x)
+ x = F.silu(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ buffer,
+ )
+
+ def clear_buffer(self, incremental_state):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ del saved_state['prev_input']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class GPTBlock(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+ kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False, norm_cls=LayerNorm):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.layer_norm1 = norm_cls(c)
+ self.self_attn = CausalSelfAttention(
+ c, num_heads, dropout=attention_dropout
+ )
+ self.layer_norm2 = norm_cls(c)
+ self.ffn = TransformerFFNLayer(
+ c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+ self.post_ln = post_ln
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ self_attn_mask=None,
+ attn_out=None,
+ spk_pos_ids_flat=None,
+ **kwargs,
+ ):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm1(x)
+
+ x, _ = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ incremental_state=incremental_state,
+ attn_mask=self_attn_mask,
+ spk_pos_ids_flat=spk_pos_ids_flat,
+ need_weights=False
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm1(x)
+
+ attn_logits = None
+
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm2(x)
+ x = self.ffn(x, incremental_state=incremental_state)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm2(x)
+ return x, attn_logits
+
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+ self.encoder_attn.clear_buffer(incremental_state)
+ self.ffn.clear_buffer(incremental_state)
+
+ def set_buffer(self, name, tensor, incremental_state):
+ return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class GPTLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, ffn_hidden_size=1024, post_ln=False,
+ lm_num_layers=10, norm_cls=LayerNorm):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.op = GPTBlock(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
+ post_ln=post_ln, norm_cls=norm_cls)
+
+ # init all weights
+ self.apply(self._init_weights)
+ # apply special scaled init to the residual projections, per GPT-2 paper
+ for pn, p in self.named_parameters():
+ if pn.endswith('ffn_2.weight') or pn.endswith('out_proj.weight'):
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * lm_num_layers))
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+ if module.bias is not None:
+ torch.nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+ @torch.autocast(device_type='cuda')
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+ def clear_buffer(self, *args):
+ return self.op.clear_buffer(*args)
+
+ def set_buffer(self, *args):
+ return self.op.set_buffer(*args)
diff --git a/modules/commons/improved_diffusion/__init__.py b/modules/commons/improved_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9665a0d63f695eab303318d824dad14041c7cde9
--- /dev/null
+++ b/modules/commons/improved_diffusion/__init__.py
@@ -0,0 +1,3 @@
+"""
+Codebase for "Improved Denoising Diffusion Probabilistic Models".
+"""
diff --git a/modules/commons/improved_diffusion/dist_util.py b/modules/commons/improved_diffusion/dist_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f665604d6baaf5df6008f131c86cf0779c8b208a
--- /dev/null
+++ b/modules/commons/improved_diffusion/dist_util.py
@@ -0,0 +1,82 @@
+"""
+Helpers for distributed training.
+"""
+
+import io
+import os
+import socket
+
+import blobfile as bf
+from mpi4py import MPI
+import torch as th
+import torch.distributed as dist
+
+# Change this to reflect your cluster layout.
+# The GPU for a given rank is (rank % GPUS_PER_NODE).
+GPUS_PER_NODE = 8
+
+SETUP_RETRY_COUNT = 3
+
+
+def setup_dist():
+ """
+ Setup a distributed process group.
+ """
+ if dist.is_initialized():
+ return
+
+ comm = MPI.COMM_WORLD
+ backend = "gloo" if not th.cuda.is_available() else "nccl"
+
+ if backend == "gloo":
+ hostname = "localhost"
+ else:
+ hostname = socket.gethostbyname(socket.getfqdn())
+ os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
+ os.environ["RANK"] = str(comm.rank)
+ os.environ["WORLD_SIZE"] = str(comm.size)
+
+ port = comm.bcast(_find_free_port(), root=0)
+ os.environ["MASTER_PORT"] = str(port)
+ dist.init_process_group(backend=backend, init_method="env://")
+
+
+def dev():
+ """
+ Get the device to use for torch.distributed.
+ """
+ if th.cuda.is_available():
+ return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
+ return th.device("cpu")
+
+
+def load_state_dict(path, **kwargs):
+ """
+ Load a PyTorch file without redundant fetches across MPI ranks.
+ """
+ if MPI.COMM_WORLD.Get_rank() == 0:
+ with bf.BlobFile(path, "rb") as f:
+ data = f.read()
+ else:
+ data = None
+ data = MPI.COMM_WORLD.bcast(data)
+ return th.load(io.BytesIO(data), **kwargs)
+
+
+def sync_params(params):
+ """
+ Synchronize a sequence of Tensors across ranks from rank 0.
+ """
+ for p in params:
+ with th.no_grad():
+ dist.broadcast(p, 0)
+
+
+def _find_free_port():
+ try:
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.bind(("", 0))
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ return s.getsockname()[1]
+ finally:
+ s.close()
diff --git a/modules/commons/improved_diffusion/fp16_util.py b/modules/commons/improved_diffusion/fp16_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..23e0418153143200a718f56077b3360f30f4c663
--- /dev/null
+++ b/modules/commons/improved_diffusion/fp16_util.py
@@ -0,0 +1,76 @@
+"""
+Helpers to train with 16-bit precision.
+"""
+
+import torch.nn as nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+
+def convert_module_to_f16(l):
+ """
+ Convert primitive modules to float16.
+ """
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l.weight.data = l.weight.data.half()
+ l.bias.data = l.bias.data.half()
+
+
+def convert_module_to_f32(l):
+ """
+ Convert primitive modules to float32, undoing convert_module_to_f16().
+ """
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l.weight.data = l.weight.data.float()
+ l.bias.data = l.bias.data.float()
+
+
+def make_master_params(model_params):
+ """
+ Copy model parameters into a (differently-shaped) list of full-precision
+ parameters.
+ """
+ master_params = _flatten_dense_tensors(
+ [param.detach().float() for param in model_params]
+ )
+ master_params = nn.Parameter(master_params)
+ master_params.requires_grad = True
+ return [master_params]
+
+
+def model_grads_to_master_grads(model_params, master_params):
+ """
+ Copy the gradients from the model parameters into the master parameters
+ from make_master_params().
+ """
+ master_params[0].grad = _flatten_dense_tensors(
+ [param.grad.data.detach().float() for param in model_params]
+ )
+
+
+def master_params_to_model_params(model_params, master_params):
+ """
+ Copy the master parameter data back into the model parameters.
+ """
+ # Without copying to a list, if a generator is passed, this will
+ # silently not copy any parameters.
+ model_params = list(model_params)
+
+ for param, master_param in zip(
+ model_params, unflatten_master_params(model_params, master_params)
+ ):
+ param.detach().copy_(master_param)
+
+
+def unflatten_master_params(model_params, master_params):
+ """
+ Unflatten the master parameters to look like model_params.
+ """
+ return _unflatten_dense_tensors(master_params[0].detach(), model_params)
+
+
+def zero_grad(model_params):
+ for param in model_params:
+ # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
+ if param.grad is not None:
+ param.grad.detach_()
+ param.grad.zero_()
diff --git a/modules/commons/improved_diffusion/gaussian_diffusion.py b/modules/commons/improved_diffusion/gaussian_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e76eafab7a071e14b92821dbe0d8fd4382bdccd
--- /dev/null
+++ b/modules/commons/improved_diffusion/gaussian_diffusion.py
@@ -0,0 +1,870 @@
+"""
+This code started out as a PyTorch port of Ho et al's diffusion models:
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
+
+Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
+"""
+
+import enum
+import math
+
+import numpy as np
+import torch as th
+
+from .nn import mean_flat
+from .losses import normal_kl, discretized_gaussian_log_likelihood
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ return np.linspace(
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif schedule_name == "cosine":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+
+ Ported directly from here, and then adapted over time to further experimentation.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
+ :param model_var_type: a ModelVarType determining how variance is output.
+ :param loss_type: a LossType determining the loss function to use.
+ :param rescale_timesteps: if True, pass floating point timesteps into the
+ model so that they are always scaled like in the
+ original paper (0 to 1000).
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ rescale_timesteps=False,
+ ):
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+ self.rescale_timesteps = rescale_timesteps
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ # log calculation clipped because the posterior variance is 0 at the
+ # beginning of the diffusion chain.
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ )
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev)
+ * np.sqrt(alphas)
+ / (1.0 - self.alphas_cumprod)
+ )
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ )
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+
+ In other words, sample from q(x_t | x_0).
+
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+
+ q(x_{t-1} | x_t, x_0)
+
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+ ):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, C = x.shape[:2]
+ assert t.shape == (B,)
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ if self.model_var_type == ModelVarType.LEARNED:
+ model_log_variance = model_var_values
+ model_variance = th.exp(model_log_variance)
+ else:
+ min_log = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x.shape
+ )
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
+ )
+ model_mean = model_output
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(
+ x_start=pred_xstart, x_t=x, t=t
+ )
+ else:
+ raise NotImplementedError(self.model_mean_type)
+
+ assert (
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ )
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
+ assert x_t.shape == xprev.shape
+ return ( # (xprev - coef2*x_t) / coef1
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
+ - _extract_into_tensor(
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
+ )
+ * x_t
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _scale_timesteps(self, t):
+ if self.rescale_timesteps:
+ return t.float() * (1000.0 / self.num_timesteps)
+ return t
+
+ def p_sample(
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ + th.sqrt(1 - alpha_bar_next) * eps
+ )
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ ):
+ final = sample
+ if 'sample_merge' in final:
+ return final["sample_merge"]
+ else:
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ # mask = model_kwargs['mask']
+ # img = out["sample"] * mask
+ # if model_kwargs.get('replace_val') is not None:
+ # replace_idx = model_kwargs['replace_idx']
+ # replace_val = model_kwargs['replace_val']
+ # x_t = self.q_sample(replace_val, t - 1) if t > 0 else replace_val
+ # B, T = img.shape[:2]
+ # img = img.reshape(B, T, -1, 3)
+ # img[:, :, replace_idx] = x_t[:, :, replace_idx]
+ # out["sample"] = img = img.flatten(2)
+ # if 'frames_inp' in model_kwargs:
+ # x_t = self.q_sample(model_kwargs['frames_inp'], t - 1) \
+ # if t > 0 else model_kwargs['frames_inp']
+ # img = img * mask + x_t * (1 - mask)
+ # out['sample_merge'] = img
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ """
+ Get a term for the variational lower-bound.
+
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+ """
+ Compute training losses for a single timestep.
+
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+
+ terms = {}
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+ model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ target = {
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ assert model_output.shape == target.shape == x_start.shape
+
+ # mask = model_kwargs['mask']
+ # if mask.shape != x_start.shape:
+ # mask = mask.expand_as(x_start)
+ # mask = mask.flatten(2)
+ #
+ # terms["mse"] = (target - model_output) ** 2
+ # terms["mse"] = terms["mse"].flatten(2)
+ # terms["mse"] = (terms["mse"] * mask).sum(-1) / mask.sum(-1)
+ terms["mse"] = mean_flat((target - model_output) ** 2)
+ # print(">>>", (target - model_output).abs().mean())
+
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+
+ This term can't be optimized, as it only depends on the encoder.
+
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
diff --git a/modules/commons/improved_diffusion/image_datasets.py b/modules/commons/improved_diffusion/image_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e49d2394622e5b7ea988e4afe9fef117dedf6a9
--- /dev/null
+++ b/modules/commons/improved_diffusion/image_datasets.py
@@ -0,0 +1,106 @@
+from PIL import Image
+import blobfile as bf
+from mpi4py import MPI
+import numpy as np
+from torch.utils.data import DataLoader, Dataset
+
+
+def load_data(
+ *, data_dir, batch_size, image_size, class_cond=False, deterministic=False
+):
+ """
+ For a dataset, create a generator over (images, kwargs) pairs.
+
+ Each images is an NCHW float tensor, and the kwargs dict contains zero or
+ more keys, each of which map to a batched Tensor of their own.
+ The kwargs dict can be used for class labels, in which case the key is "y"
+ and the values are integer tensors of class labels.
+
+ :param data_dir: a dataset directory.
+ :param batch_size: the batch size of each returned pair.
+ :param image_size: the size to which images are resized.
+ :param class_cond: if True, include a "y" key in returned dicts for class
+ label. If classes are not available and this is true, an
+ exception will be raised.
+ :param deterministic: if True, yield results in a deterministic order.
+ """
+ if not data_dir:
+ raise ValueError("unspecified data directory")
+ all_files = _list_image_files_recursively(data_dir)
+ classes = None
+ if class_cond:
+ # Assume classes are the first part of the filename,
+ # before an underscore.
+ class_names = [bf.basename(path).split("_")[0] for path in all_files]
+ sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
+ classes = [sorted_classes[x] for x in class_names]
+ dataset = ImageDataset(
+ image_size,
+ all_files,
+ classes=classes,
+ shard=MPI.COMM_WORLD.Get_rank(),
+ num_shards=MPI.COMM_WORLD.Get_size(),
+ )
+ if deterministic:
+ loader = DataLoader(
+ dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
+ )
+ else:
+ loader = DataLoader(
+ dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
+ )
+ while True:
+ yield from loader
+
+
+def _list_image_files_recursively(data_dir):
+ results = []
+ for entry in sorted(bf.listdir(data_dir)):
+ full_path = bf.join(data_dir, entry)
+ ext = entry.split(".")[-1]
+ if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
+ results.append(full_path)
+ elif bf.isdir(full_path):
+ results.extend(_list_image_files_recursively(full_path))
+ return results
+
+
+class ImageDataset(Dataset):
+ def __init__(self, resolution, image_paths, classes=None, shard=0, num_shards=1):
+ super().__init__()
+ self.resolution = resolution
+ self.local_images = image_paths[shard:][::num_shards]
+ self.local_classes = None if classes is None else classes[shard:][::num_shards]
+
+ def __len__(self):
+ return len(self.local_images)
+
+ def __getitem__(self, idx):
+ path = self.local_images[idx]
+ with bf.BlobFile(path, "rb") as f:
+ pil_image = Image.open(f)
+ pil_image.load()
+
+ # We are not on a new enough PIL to support the `reducing_gap`
+ # argument, which uses BOX downsampling at powers of two first.
+ # Thus, we do it by hand to improve downsample quality.
+ while min(*pil_image.size) >= 2 * self.resolution:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = self.resolution / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image.convert("RGB"))
+ crop_y = (arr.shape[0] - self.resolution) // 2
+ crop_x = (arr.shape[1] - self.resolution) // 2
+ arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution]
+ arr = arr.astype(np.float32) / 127.5 - 1
+
+ out_dict = {}
+ if self.local_classes is not None:
+ out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
+ return np.transpose(arr, [2, 0, 1]), out_dict
diff --git a/modules/commons/improved_diffusion/logger.py b/modules/commons/improved_diffusion/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1d856dcfea6b56a2ee8d37b286887430dbfac30
--- /dev/null
+++ b/modules/commons/improved_diffusion/logger.py
@@ -0,0 +1,495 @@
+"""
+Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
+https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
+"""
+
+import os
+import sys
+import shutil
+import os.path as osp
+import json
+import time
+import datetime
+import tempfile
+import warnings
+from collections import defaultdict
+from contextlib import contextmanager
+
+DEBUG = 10
+INFO = 20
+WARN = 30
+ERROR = 40
+
+DISABLED = 50
+
+
+class KVWriter(object):
+ def writekvs(self, kvs):
+ raise NotImplementedError
+
+
+class SeqWriter(object):
+ def writeseq(self, seq):
+ raise NotImplementedError
+
+
+class HumanOutputFormat(KVWriter, SeqWriter):
+ def __init__(self, filename_or_file):
+ if isinstance(filename_or_file, str):
+ self.file = open(filename_or_file, "wt")
+ self.own_file = True
+ else:
+ assert hasattr(filename_or_file, "read"), (
+ "expected file or str, got %s" % filename_or_file
+ )
+ self.file = filename_or_file
+ self.own_file = False
+
+ def writekvs(self, kvs):
+ # Create strings for printing
+ key2str = {}
+ for (key, val) in sorted(kvs.items()):
+ if hasattr(val, "__float__"):
+ valstr = "%-8.3g" % val
+ else:
+ valstr = str(val)
+ key2str[self._truncate(key)] = self._truncate(valstr)
+
+ # Find max widths
+ if len(key2str) == 0:
+ print("WARNING: tried to write empty key-value dict")
+ return
+ else:
+ keywidth = max(map(len, key2str.keys()))
+ valwidth = max(map(len, key2str.values()))
+
+ # Write out the data
+ dashes = "-" * (keywidth + valwidth + 7)
+ lines = [dashes]
+ for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
+ lines.append(
+ "| %s%s | %s%s |"
+ % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
+ )
+ lines.append(dashes)
+ self.file.write("\n".join(lines) + "\n")
+
+ # Flush the output to the file
+ self.file.flush()
+
+ def _truncate(self, s):
+ maxlen = 30
+ return s[: maxlen - 3] + "..." if len(s) > maxlen else s
+
+ def writeseq(self, seq):
+ seq = list(seq)
+ for (i, elem) in enumerate(seq):
+ self.file.write(elem)
+ if i < len(seq) - 1: # add space unless this is the last one
+ self.file.write(" ")
+ self.file.write("\n")
+ self.file.flush()
+
+ def close(self):
+ if self.own_file:
+ self.file.close()
+
+
+class JSONOutputFormat(KVWriter):
+ def __init__(self, filename):
+ self.file = open(filename, "wt")
+
+ def writekvs(self, kvs):
+ for k, v in sorted(kvs.items()):
+ if hasattr(v, "dtype"):
+ kvs[k] = float(v)
+ self.file.write(json.dumps(kvs) + "\n")
+ self.file.flush()
+
+ def close(self):
+ self.file.close()
+
+
+class CSVOutputFormat(KVWriter):
+ def __init__(self, filename):
+ self.file = open(filename, "w+t")
+ self.keys = []
+ self.sep = ","
+
+ def writekvs(self, kvs):
+ # Add our current row to the history
+ extra_keys = list(kvs.keys() - self.keys)
+ extra_keys.sort()
+ if extra_keys:
+ self.keys.extend(extra_keys)
+ self.file.seek(0)
+ lines = self.file.readlines()
+ self.file.seek(0)
+ for (i, k) in enumerate(self.keys):
+ if i > 0:
+ self.file.write(",")
+ self.file.write(k)
+ self.file.write("\n")
+ for line in lines[1:]:
+ self.file.write(line[:-1])
+ self.file.write(self.sep * len(extra_keys))
+ self.file.write("\n")
+ for (i, k) in enumerate(self.keys):
+ if i > 0:
+ self.file.write(",")
+ v = kvs.get(k)
+ if v is not None:
+ self.file.write(str(v))
+ self.file.write("\n")
+ self.file.flush()
+
+ def close(self):
+ self.file.close()
+
+
+class TensorBoardOutputFormat(KVWriter):
+ """
+ Dumps key/value pairs into TensorBoard's numeric format.
+ """
+
+ def __init__(self, dir):
+ os.makedirs(dir, exist_ok=True)
+ self.dir = dir
+ self.step = 1
+ prefix = "events"
+ path = osp.join(osp.abspath(dir), prefix)
+ import tensorflow as tf
+ from tensorflow.python import pywrap_tensorflow
+ from tensorflow.core.util import event_pb2
+ from tensorflow.python.util import compat
+
+ self.tf = tf
+ self.event_pb2 = event_pb2
+ self.pywrap_tensorflow = pywrap_tensorflow
+ self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
+
+ def writekvs(self, kvs):
+ def summary_val(k, v):
+ kwargs = {"tag": k, "simple_value": float(v)}
+ return self.tf.Summary.Value(**kwargs)
+
+ summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
+ event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
+ event.step = (
+ self.step
+ ) # is there any reason why you'd want to specify the step?
+ self.writer.WriteEvent(event)
+ self.writer.Flush()
+ self.step += 1
+
+ def close(self):
+ if self.writer:
+ self.writer.Close()
+ self.writer = None
+
+
+def make_output_format(format, ev_dir, log_suffix=""):
+ os.makedirs(ev_dir, exist_ok=True)
+ if format == "stdout":
+ return HumanOutputFormat(sys.stdout)
+ elif format == "log":
+ return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
+ elif format == "json":
+ return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
+ elif format == "csv":
+ return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
+ elif format == "tensorboard":
+ return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
+ else:
+ raise ValueError("Unknown format specified: %s" % (format,))
+
+
+# ================================================================
+# API
+# ================================================================
+
+
+def logkv(key, val):
+ """
+ Log a value of some diagnostic
+ Call this once for each diagnostic quantity, each iteration
+ If called many times, last value will be used.
+ """
+ get_current().logkv(key, val)
+
+
+def logkv_mean(key, val):
+ """
+ The same as logkv(), but if called many times, values averaged.
+ """
+ get_current().logkv_mean(key, val)
+
+
+def logkvs(d):
+ """
+ Log a dictionary of key-value pairs
+ """
+ for (k, v) in d.items():
+ logkv(k, v)
+
+
+def dumpkvs():
+ """
+ Write all of the diagnostics from the current iteration
+ """
+ return get_current().dumpkvs()
+
+
+def getkvs():
+ return get_current().name2val
+
+
+def log(*args, level=INFO):
+ """
+ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
+ """
+ get_current().log(*args, level=level)
+
+
+def debug(*args):
+ log(*args, level=DEBUG)
+
+
+def info(*args):
+ log(*args, level=INFO)
+
+
+def warn(*args):
+ log(*args, level=WARN)
+
+
+def error(*args):
+ log(*args, level=ERROR)
+
+
+def set_level(level):
+ """
+ Set logging threshold on current logger.
+ """
+ get_current().set_level(level)
+
+
+def set_comm(comm):
+ get_current().set_comm(comm)
+
+
+def get_dir():
+ """
+ Get directory that log files are being written to.
+ will be None if there is no output directory (i.e., if you didn't call start)
+ """
+ return get_current().get_dir()
+
+
+record_tabular = logkv
+dump_tabular = dumpkvs
+
+
+@contextmanager
+def profile_kv(scopename):
+ logkey = "wait_" + scopename
+ tstart = time.time()
+ try:
+ yield
+ finally:
+ get_current().name2val[logkey] += time.time() - tstart
+
+
+def profile(n):
+ """
+ Usage:
+ @profile("my_func")
+ def my_func(): code
+ """
+
+ def decorator_with_name(func):
+ def func_wrapper(*args, **kwargs):
+ with profile_kv(n):
+ return func(*args, **kwargs)
+
+ return func_wrapper
+
+ return decorator_with_name
+
+
+# ================================================================
+# Backend
+# ================================================================
+
+
+def get_current():
+ if Logger.CURRENT is None:
+ _configure_default_logger()
+
+ return Logger.CURRENT
+
+
+class Logger(object):
+ DEFAULT = None # A logger with no output files. (See right below class definition)
+ # So that you can still log to the terminal without setting up any output files
+ CURRENT = None # Current logger being used by the free functions above
+
+ def __init__(self, dir, output_formats, comm=None):
+ self.name2val = defaultdict(float) # values this iteration
+ self.name2cnt = defaultdict(int)
+ self.level = INFO
+ self.dir = dir
+ self.output_formats = output_formats
+ self.comm = comm
+
+ # Logging API, forwarded
+ # ----------------------------------------
+ def logkv(self, key, val):
+ self.name2val[key] = val
+
+ def logkv_mean(self, key, val):
+ oldval, cnt = self.name2val[key], self.name2cnt[key]
+ self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
+ self.name2cnt[key] = cnt + 1
+
+ def dumpkvs(self):
+ if self.comm is None:
+ d = self.name2val
+ else:
+ d = mpi_weighted_mean(
+ self.comm,
+ {
+ name: (val, self.name2cnt.get(name, 1))
+ for (name, val) in self.name2val.items()
+ },
+ )
+ if self.comm.rank != 0:
+ d["dummy"] = 1 # so we don't get a warning about empty dict
+ out = d.copy() # Return the dict for unit testing purposes
+ for fmt in self.output_formats:
+ if isinstance(fmt, KVWriter):
+ fmt.writekvs(d)
+ self.name2val.clear()
+ self.name2cnt.clear()
+ return out
+
+ def log(self, *args, level=INFO):
+ if self.level <= level:
+ self._do_log(args)
+
+ # Configuration
+ # ----------------------------------------
+ def set_level(self, level):
+ self.level = level
+
+ def set_comm(self, comm):
+ self.comm = comm
+
+ def get_dir(self):
+ return self.dir
+
+ def close(self):
+ for fmt in self.output_formats:
+ fmt.close()
+
+ # Misc
+ # ----------------------------------------
+ def _do_log(self, args):
+ for fmt in self.output_formats:
+ if isinstance(fmt, SeqWriter):
+ fmt.writeseq(map(str, args))
+
+
+def get_rank_without_mpi_import():
+ # check environment variables here instead of importing mpi4py
+ # to avoid calling MPI_Init() when this module is imported
+ for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
+ if varname in os.environ:
+ return int(os.environ[varname])
+ return 0
+
+
+def mpi_weighted_mean(comm, local_name2valcount):
+ """
+ Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
+ Perform a weighted average over dicts that are each on a different node
+ Input: local_name2valcount: dict mapping key -> (value, count)
+ Returns: key -> mean
+ """
+ all_name2valcount = comm.gather(local_name2valcount)
+ if comm.rank == 0:
+ name2sum = defaultdict(float)
+ name2count = defaultdict(float)
+ for n2vc in all_name2valcount:
+ for (name, (val, count)) in n2vc.items():
+ try:
+ val = float(val)
+ except ValueError:
+ if comm.rank == 0:
+ warnings.warn(
+ "WARNING: tried to compute mean on non-float {}={}".format(
+ name, val
+ )
+ )
+ else:
+ name2sum[name] += val * count
+ name2count[name] += count
+ return {name: name2sum[name] / name2count[name] for name in name2sum}
+ else:
+ return {}
+
+
+def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
+ """
+ If comm is provided, average all numerical stats across that comm
+ """
+ if dir is None:
+ dir = os.getenv("OPENAI_LOGDIR")
+ if dir is None:
+ dir = osp.join(
+ tempfile.gettempdir(),
+ datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
+ )
+ assert isinstance(dir, str)
+ dir = os.path.expanduser(dir)
+ os.makedirs(os.path.expanduser(dir), exist_ok=True)
+
+ rank = get_rank_without_mpi_import()
+ if rank > 0:
+ log_suffix = log_suffix + "-rank%03i" % rank
+
+ if format_strs is None:
+ if rank == 0:
+ format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
+ else:
+ format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
+ format_strs = filter(None, format_strs)
+ output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
+
+ Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
+ if output_formats:
+ log("Logging to %s" % dir)
+
+
+def _configure_default_logger():
+ configure()
+ Logger.DEFAULT = Logger.CURRENT
+
+
+def reset():
+ if Logger.CURRENT is not Logger.DEFAULT:
+ Logger.CURRENT.close()
+ Logger.CURRENT = Logger.DEFAULT
+ log("Reset logger")
+
+
+@contextmanager
+def scoped_configure(dir=None, format_strs=None, comm=None):
+ prevlogger = Logger.CURRENT
+ configure(dir=dir, format_strs=format_strs, comm=comm)
+ try:
+ yield
+ finally:
+ Logger.CURRENT.close()
+ Logger.CURRENT = prevlogger
+
diff --git a/modules/commons/improved_diffusion/losses.py b/modules/commons/improved_diffusion/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..251e42e4f36a31bb5e1aeda874b3a45d722000a2
--- /dev/null
+++ b/modules/commons/improved_diffusion/losses.py
@@ -0,0 +1,77 @@
+"""
+Helpers for various likelihood-based losses. These are ported from the original
+Ho et al. diffusion models codebase:
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
+"""
+
+import numpy as np
+
+import torch as th
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, th.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + th.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
diff --git a/modules/commons/improved_diffusion/nn.py b/modules/commons/improved_diffusion/nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4cd59c2324b003626b8cf4c7581effd334908d3
--- /dev/null
+++ b/modules/commons/improved_diffusion/nn.py
@@ -0,0 +1,170 @@
+"""
+Various utilities for neural networks.
+"""
+
+import math
+
+import torch as th
+import torch.nn as nn
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * th.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = th.exp(
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(th.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ with th.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with th.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = th.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
diff --git a/modules/commons/improved_diffusion/resample.py b/modules/commons/improved_diffusion/resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..c82eccdcd47c468d41e7cbe02de6a731f2c9bf81
--- /dev/null
+++ b/modules/commons/improved_diffusion/resample.py
@@ -0,0 +1,154 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch as th
+import torch.distributed as dist
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ elif name == "loss-second-moment":
+ return LossSecondMomentResampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+
+ :param batch_size: the number of timesteps.
+ :param device: the torch device to save to.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / np.sum(w)
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+ indices = th.from_numpy(indices_np).long().to(device)
+ weights_np = 1 / (len(p) * p[indices_np])
+ weights = th.from_numpy(weights_np).float().to(device)
+ return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ self.diffusion = diffusion
+ self._weights = np.ones([diffusion.num_timesteps])
+
+ def weights(self):
+ return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+ def update_with_local_losses(self, local_ts, local_losses):
+ """
+ Update the reweighting using losses from a model.
+
+ Call this method from each rank with a batch of timesteps and the
+ corresponding losses for each of those timesteps.
+ This method will perform synchronization to make sure all of the ranks
+ maintain the exact same reweighting.
+
+ :param local_ts: an integer Tensor of timesteps.
+ :param local_losses: a 1D Tensor of losses.
+ """
+ batch_sizes = [
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
+ for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(
+ batch_sizes,
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+ )
+
+ # Pad all_gather batches to be the maximum batch size.
+ batch_sizes = [x.item() for x in batch_sizes]
+ max_bs = max(batch_sizes)
+
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+ dist.all_gather(timestep_batches, local_ts)
+ dist.all_gather(loss_batches, local_losses)
+ timesteps = [
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+ ]
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+ self.update_with_all_losses(timesteps, losses)
+
+ @abstractmethod
+ def update_with_all_losses(self, ts, losses):
+ """
+ Update the reweighting using losses from a model.
+
+ Sub-classes should override this method to update the reweighting
+ using losses from the model.
+
+ This method directly updates the reweighting without synchronizing
+ between workers. It is called by update_with_local_losses from all
+ ranks with identical arguments. Thus, it should have deterministic
+ behavior to maintain state across workers.
+
+ :param ts: a list of int timesteps.
+ :param losses: a list of float losses, one per timestep.
+ """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+ self.diffusion = diffusion
+ self.history_per_term = history_per_term
+ self.uniform_prob = uniform_prob
+ self._loss_history = np.zeros(
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
+ )
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
+
+ def weights(self):
+ if not self._warmed_up():
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+ weights /= np.sum(weights)
+ weights *= 1 - self.uniform_prob
+ weights += self.uniform_prob / len(weights)
+ return weights
+
+ def update_with_all_losses(self, ts, losses):
+ for t, loss in zip(ts, losses):
+ if self._loss_counts[t] == self.history_per_term:
+ # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
+ self._loss_history[t, -1] = loss
+ else:
+ self._loss_history[t, self._loss_counts[t]] = loss
+ self._loss_counts[t] += 1
+
+ def _warmed_up(self):
+ return (self._loss_counts == self.history_per_term).all()
diff --git a/modules/commons/improved_diffusion/respace.py b/modules/commons/improved_diffusion/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..045d58df956e6ddb04216e972bffff47c59bf488
--- /dev/null
+++ b/modules/commons/improved_diffusion/respace.py
@@ -0,0 +1,122 @@
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def training_losses(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ if self.rescale_timesteps:
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
diff --git a/modules/commons/improved_diffusion/train_util.py b/modules/commons/improved_diffusion/train_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1867604145736352dc51ab05b6caae8b541a6ebb
--- /dev/null
+++ b/modules/commons/improved_diffusion/train_util.py
@@ -0,0 +1,356 @@
+import copy
+import functools
+import os
+
+import blobfile as bf
+import numpy as np
+import torch as th
+import torch.distributed as dist
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+
+from . import dist_util, logger
+from .fp16_util import (
+ make_master_params,
+ master_params_to_model_params,
+ model_grads_to_master_grads,
+ unflatten_master_params,
+ zero_grad,
+)
+from .nn import update_ema
+from .resample import LossAwareSampler, UniformSampler
+
+# For ImageNet experiments, this was a good default value.
+# We found that the lg_loss_scale quickly climbed to
+# 20-21 within the first ~1K steps of training.
+INITIAL_LOG_LOSS_SCALE = 20.0
+
+
+class TrainLoop:
+ def __init__(
+ self,
+ *,
+ model,
+ diffusion,
+ data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ schedule_sampler=None,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ ):
+ self.model = model
+ self.diffusion = diffusion
+ self.data = data
+ self.batch_size = batch_size
+ self.microbatch = microbatch if microbatch > 0 else batch_size
+ self.lr = lr
+ self.ema_rate = (
+ [ema_rate]
+ if isinstance(ema_rate, float)
+ else [float(x) for x in ema_rate.split(",")]
+ )
+ self.log_interval = log_interval
+ self.save_interval = save_interval
+ self.resume_checkpoint = resume_checkpoint
+ self.use_fp16 = use_fp16
+ self.fp16_scale_growth = fp16_scale_growth
+ self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
+ self.weight_decay = weight_decay
+ self.lr_anneal_steps = lr_anneal_steps
+
+ self.step = 0
+ self.resume_step = 0
+ self.global_batch = self.batch_size * dist.get_world_size()
+
+ self.model_params = list(self.model.parameters())
+ self.master_params = self.model_params
+ self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
+ self.sync_cuda = th.cuda.is_available()
+
+ self._load_and_sync_parameters()
+ if self.use_fp16:
+ self._setup_fp16()
+
+ self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
+ if self.resume_step:
+ self._load_optimizer_state()
+ # Model was resumed, either due to a restart or a checkpoint
+ # being specified at the command line.
+ self.ema_params = [
+ self._load_ema_parameters(rate) for rate in self.ema_rate
+ ]
+ else:
+ self.ema_params = [
+ copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
+ ]
+
+ if th.cuda.is_available():
+ self.use_ddp = True
+ self.ddp_model = DDP(
+ self.model,
+ device_ids=[dist_util.dev()],
+ output_device=dist_util.dev(),
+ broadcast_buffers=False,
+ bucket_cap_mb=128,
+ find_unused_parameters=False,
+ )
+ else:
+ if dist.get_world_size() > 1:
+ logger.warn(
+ "Distributed training requires CUDA. "
+ "Gradients will not be synchronized properly!"
+ )
+ self.use_ddp = False
+ self.ddp_model = self.model
+
+ def _load_and_sync_parameters(self):
+ resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
+
+ if resume_checkpoint:
+ self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
+ if dist.get_rank() == 0:
+ logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
+ self.model.load_state_dict(
+ dist_util.load_state_dict(
+ resume_checkpoint, map_location=dist_util.dev()
+ )
+ )
+
+ dist_util.sync_params(self.model.parameters())
+
+ def _load_ema_parameters(self, rate):
+ ema_params = copy.deepcopy(self.master_params)
+
+ main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
+ ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
+ if ema_checkpoint:
+ if dist.get_rank() == 0:
+ logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
+ state_dict = dist_util.load_state_dict(
+ ema_checkpoint, map_location=dist_util.dev()
+ )
+ ema_params = self._state_dict_to_master_params(state_dict)
+
+ dist_util.sync_params(ema_params)
+ return ema_params
+
+ def _load_optimizer_state(self):
+ main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
+ opt_checkpoint = bf.join(
+ bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
+ )
+ if bf.exists(opt_checkpoint):
+ logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
+ state_dict = dist_util.load_state_dict(
+ opt_checkpoint, map_location=dist_util.dev()
+ )
+ self.opt.load_state_dict(state_dict)
+
+ def _setup_fp16(self):
+ self.master_params = make_master_params(self.model_params)
+ self.model.convert_to_fp16()
+
+ def run_loop(self):
+ while (
+ not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps
+ ):
+ batch, cond = next(self.data)
+ self.run_step(batch, cond)
+ if self.step % self.log_interval == 0:
+ logger.dumpkvs()
+ if self.step % self.save_interval == 0:
+ self.save()
+ # Run for a finite amount of time in integration tests.
+ if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
+ return
+ self.step += 1
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+
+ def run_step(self, batch, cond):
+ self.forward_backward(batch, cond)
+ if self.use_fp16:
+ self.optimize_fp16()
+ else:
+ self.optimize_normal()
+ self.log_step()
+
+ def forward_backward(self, batch, cond):
+ zero_grad(self.model_params)
+ for i in range(0, batch.shape[0], self.microbatch):
+ micro = batch[i : i + self.microbatch].to(dist_util.dev())
+ micro_cond = {
+ k: v[i : i + self.microbatch].to(dist_util.dev())
+ for k, v in cond.items()
+ }
+ last_batch = (i + self.microbatch) >= batch.shape[0]
+ t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
+
+ compute_losses = functools.partial(
+ self.diffusion.training_losses,
+ self.ddp_model,
+ micro,
+ t,
+ model_kwargs=micro_cond,
+ )
+
+ if last_batch or not self.use_ddp:
+ losses = compute_losses()
+ else:
+ with self.ddp_model.no_sync():
+ losses = compute_losses()
+
+ if isinstance(self.schedule_sampler, LossAwareSampler):
+ self.schedule_sampler.update_with_local_losses(
+ t, losses["loss"].detach()
+ )
+
+ loss = (losses["loss"] * weights).mean()
+ log_loss_dict(
+ self.diffusion, t, {k: v * weights for k, v in losses.items()}
+ )
+ if self.use_fp16:
+ loss_scale = 2 ** self.lg_loss_scale
+ (loss * loss_scale).backward()
+ else:
+ loss.backward()
+
+ def optimize_fp16(self):
+ if any(not th.isfinite(p.grad).all() for p in self.model_params):
+ self.lg_loss_scale -= 1
+ logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
+ return
+
+ model_grads_to_master_grads(self.model_params, self.master_params)
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
+ self._log_grad_norm()
+ self._anneal_lr()
+ self.opt.step()
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ update_ema(params, self.master_params, rate=rate)
+ master_params_to_model_params(self.model_params, self.master_params)
+ self.lg_loss_scale += self.fp16_scale_growth
+
+ def optimize_normal(self):
+ self._log_grad_norm()
+ self._anneal_lr()
+ self.opt.step()
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ update_ema(params, self.master_params, rate=rate)
+
+ def _log_grad_norm(self):
+ sqsum = 0.0
+ for p in self.master_params:
+ sqsum += (p.grad ** 2).sum().item()
+ logger.logkv_mean("grad_norm", np.sqrt(sqsum))
+
+ def _anneal_lr(self):
+ if not self.lr_anneal_steps:
+ return
+ frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
+ lr = self.lr * (1 - frac_done)
+ for param_group in self.opt.param_groups:
+ param_group["lr"] = lr
+
+ def log_step(self):
+ logger.logkv("step", self.step + self.resume_step)
+ logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
+ if self.use_fp16:
+ logger.logkv("lg_loss_scale", self.lg_loss_scale)
+
+ def save(self):
+ def save_checkpoint(rate, params):
+ state_dict = self._master_params_to_state_dict(params)
+ if dist.get_rank() == 0:
+ logger.log(f"saving model {rate}...")
+ if not rate:
+ filename = f"model{(self.step+self.resume_step):06d}.pt"
+ else:
+ filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
+ with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
+ th.save(state_dict, f)
+
+ save_checkpoint(0, self.master_params)
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ save_checkpoint(rate, params)
+
+ if dist.get_rank() == 0:
+ with bf.BlobFile(
+ bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
+ "wb",
+ ) as f:
+ th.save(self.opt.state_dict(), f)
+
+ dist.barrier()
+
+ def _master_params_to_state_dict(self, master_params):
+ if self.use_fp16:
+ master_params = unflatten_master_params(
+ self.model.parameters(), master_params
+ )
+ state_dict = self.model.state_dict()
+ for i, (name, _value) in enumerate(self.model.named_parameters()):
+ assert name in state_dict
+ state_dict[name] = master_params[i]
+ return state_dict
+
+ def _state_dict_to_master_params(self, state_dict):
+ params = [state_dict[name] for name, _ in self.model.named_parameters()]
+ if self.use_fp16:
+ return make_master_params(params)
+ else:
+ return params
+
+
+def parse_resume_step_from_filename(filename):
+ """
+ Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
+ checkpoint's number of steps.
+ """
+ split = filename.split("model")
+ if len(split) < 2:
+ return 0
+ split1 = split[-1].split(".")[0]
+ try:
+ return int(split1)
+ except ValueError:
+ return 0
+
+
+def get_blob_logdir():
+ return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir())
+
+
+def find_resume_checkpoint():
+ # On your infrastructure, you may want to override this to automatically
+ # discover the latest checkpoint on your blob storage, etc.
+ return None
+
+
+def find_ema_checkpoint(main_checkpoint, step, rate):
+ if main_checkpoint is None:
+ return None
+ filename = f"ema_{rate}_{(step):06d}.pt"
+ path = bf.join(bf.dirname(main_checkpoint), filename)
+ if bf.exists(path):
+ return path
+ return None
+
+
+def log_loss_dict(diffusion, ts, losses):
+ for key, values in losses.items():
+ logger.logkv_mean(key, values.mean().item())
+ # Log the quantiles (four quartiles, in particular).
+ for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
+ quartile = int(4 * sub_t / diffusion.num_timesteps)
+ logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
diff --git a/modules/commons/layers.py b/modules/commons/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..88e1c75876050fa05a768a5ae0467fdfc05bb006
--- /dev/null
+++ b/modules/commons/layers.py
@@ -0,0 +1,50 @@
+import torch
+from torch import nn
+
+
+class LayerNorm(torch.nn.LayerNorm):
+ """Layer normalization module.
+ :param int nout: output dim size
+ :param int dim: dimension to be normalized
+ """
+
+ def __init__(self, nout, dim=-1, eps=1e-5):
+ """Construct an LayerNorm object."""
+ super(LayerNorm, self).__init__(nout, eps=eps)
+ self.dim = dim
+
+ def forward(self, x):
+ """Apply layer normalization.
+ :param torch.Tensor x: input tensor
+ :return: layer normalized tensor
+ :rtype torch.Tensor
+ """
+ if self.dim == -1:
+ return super(LayerNorm, self).forward(x)
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+class Reshape(nn.Module):
+ def __init__(self, *args):
+ super(Reshape, self).__init__()
+ self.shape = args
+
+ def forward(self, x):
+ return x.view(self.shape)
+
+
+class Permute(nn.Module):
+ def __init__(self, *args):
+ super(Permute, self).__init__()
+ self.args = args
+
+ def forward(self, x):
+ return x.permute(self.args)
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+ if padding_idx is not None:
+ nn.init.constant_(m.weight[padding_idx], 0)
+ return m
diff --git a/modules/commons/normalizing_flow/glow_modules.py b/modules/commons/normalizing_flow/glow_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..c589af0f2eba2b154317912f9ad01a4163b3fd6a
--- /dev/null
+++ b/modules/commons/normalizing_flow/glow_modules.py
@@ -0,0 +1,362 @@
+import scipy
+from torch.nn import functional as F
+import torch
+from torch import nn
+import numpy as np
+from modules.commons.wavenet import WN
+from modules.tts.glow import utils
+
+
+class ActNorm(nn.Module):
+ def __init__(self, channels, ddi=False, **kwargs):
+ super().__init__()
+ self.channels = channels
+ self.initialized = not ddi
+
+ self.logs = nn.Parameter(torch.zeros(1, channels, 1))
+ self.bias = nn.Parameter(torch.zeros(1, channels, 1))
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ if x_mask is None:
+ x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
+ x_len = torch.sum(x_mask, [1, 2])
+ if not self.initialized:
+ self.initialize(x, x_mask)
+ self.initialized = True
+
+ if reverse:
+ z = (x - self.bias) * torch.exp(-self.logs) * x_mask
+ logdet = torch.sum(-self.logs) * x_len
+ else:
+ z = (self.bias + torch.exp(self.logs) * x) * x_mask
+ logdet = torch.sum(self.logs) * x_len # [b]
+ return z, logdet
+
+ def store_inverse(self):
+ pass
+
+ def set_ddi(self, ddi):
+ self.initialized = not ddi
+
+ def initialize(self, x, x_mask):
+ with torch.no_grad():
+ denom = torch.sum(x_mask, [0, 2])
+ m = torch.sum(x * x_mask, [0, 2]) / denom
+ m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
+ v = m_sq - (m ** 2)
+ logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
+
+ bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
+ logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
+
+ self.bias.data.copy_(bias_init)
+ self.logs.data.copy_(logs_init)
+
+
+class InvConvNear(nn.Module):
+ def __init__(self, channels, n_split=4, no_jacobian=False, lu=True, n_sqz=2, **kwargs):
+ super().__init__()
+ assert (n_split % 2 == 0)
+ self.channels = channels
+ self.n_split = n_split
+ self.n_sqz = n_sqz
+ self.no_jacobian = no_jacobian
+
+ w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
+ if torch.det(w_init) < 0:
+ w_init[:, 0] = -1 * w_init[:, 0]
+ self.lu = lu
+ if lu:
+ # LU decomposition can slightly speed up the inverse
+ np_p, np_l, np_u = scipy.linalg.lu(w_init)
+ np_s = np.diag(np_u)
+ np_sign_s = np.sign(np_s)
+ np_log_s = np.log(np.abs(np_s))
+ np_u = np.triu(np_u, k=1)
+ l_mask = np.tril(np.ones(w_init.shape, dtype=float), -1)
+ eye = np.eye(*w_init.shape, dtype=float)
+
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)), requires_grad=True)
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)), requires_grad=True)
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)), requires_grad=True)
+ self.register_buffer('l_mask', torch.Tensor(l_mask))
+ self.register_buffer('eye', torch.Tensor(eye))
+ else:
+ self.weight = nn.Parameter(w_init)
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ b, c, t = x.size()
+ assert (c % self.n_split == 0)
+ if x_mask is None:
+ x_mask = 1
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
+ else:
+ x_len = torch.sum(x_mask, [1, 2])
+
+ x = x.view(b, self.n_sqz, c // self.n_split, self.n_split // self.n_sqz, t)
+ x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
+
+ if self.lu:
+ self.weight, log_s = self._get_weight()
+ logdet = log_s.sum()
+ logdet = logdet * (c / self.n_split) * x_len
+ else:
+ logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
+
+ if reverse:
+ if hasattr(self, "weight_inv"):
+ weight = self.weight_inv
+ else:
+ weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
+ logdet = -logdet
+ else:
+ weight = self.weight
+ if self.no_jacobian:
+ logdet = 0
+
+ weight = weight.view(self.n_split, self.n_split, 1, 1)
+ z = F.conv2d(x, weight)
+
+ z = z.view(b, self.n_sqz, self.n_split // self.n_sqz, c // self.n_split, t)
+ z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
+ return z, logdet
+
+ def _get_weight(self):
+ l, log_s, u = self.l, self.log_s, self.u
+ l = l * self.l_mask + self.eye
+ u = u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(log_s))
+ weight = torch.matmul(self.p, torch.matmul(l, u))
+ return weight, log_s
+
+ def store_inverse(self):
+ weight, _ = self._get_weight()
+ self.weight_inv = torch.inverse(weight.float()).to(next(self.parameters()).device)
+
+
+class InvConv(nn.Module):
+ def __init__(self, channels, no_jacobian=False, lu=True, **kwargs):
+ super().__init__()
+ w_shape = [channels, channels]
+ w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(float)
+ LU_decomposed = lu
+ if not LU_decomposed:
+ # Sample a random orthogonal matrix:
+ self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
+ else:
+ np_p, np_l, np_u = scipy.linalg.lu(w_init)
+ np_s = np.diag(np_u)
+ np_sign_s = np.sign(np_s)
+ np_log_s = np.log(np.abs(np_s))
+ np_u = np.triu(np_u, k=1)
+ l_mask = np.tril(np.ones(w_shape, dtype=float), -1)
+ eye = np.eye(*w_shape, dtype=float)
+
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)))
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)))
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)))
+ self.l_mask = torch.Tensor(l_mask)
+ self.eye = torch.Tensor(eye)
+ self.w_shape = w_shape
+ self.LU = LU_decomposed
+ self.weight = None
+
+ def get_weight(self, device, reverse):
+ w_shape = self.w_shape
+ self.p = self.p.to(device)
+ self.sign_s = self.sign_s.to(device)
+ self.l_mask = self.l_mask.to(device)
+ self.eye = self.eye.to(device)
+ l = self.l * self.l_mask + self.eye
+ u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
+ dlogdet = self.log_s.sum()
+ if not reverse:
+ w = torch.matmul(self.p, torch.matmul(l, u))
+ else:
+ l = torch.inverse(l.double()).float()
+ u = torch.inverse(u.double()).float()
+ w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
+ return w.view(w_shape[0], w_shape[1], 1), dlogdet
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ """
+ log-det = log|abs(|W|)| * pixels
+ """
+ b, c, t = x.size()
+ if x_mask is None:
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
+ else:
+ x_len = torch.sum(x_mask, [1, 2])
+ logdet = 0
+ if not reverse:
+ weight, dlogdet = self.get_weight(x.device, reverse)
+ z = F.conv1d(x, weight)
+ if logdet is not None:
+ logdet = logdet + dlogdet * x_len
+ return z, logdet
+ else:
+ if self.weight is None:
+ weight, dlogdet = self.get_weight(x.device, reverse)
+ else:
+ weight, dlogdet = self.weight, self.dlogdet
+ z = F.conv1d(x, weight)
+ if logdet is not None:
+ logdet = logdet - dlogdet * x_len
+ return z, logdet
+
+ def store_inverse(self):
+ self.weight, self.dlogdet = self.get_weight('cuda', reverse=True)
+
+
+class CouplingBlock(nn.Module):
+ def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers,
+ gin_channels=0, p_dropout=0, sigmoid_scale=False, wn=None):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+ self.sigmoid_scale = sigmoid_scale
+
+ start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
+ start = torch.nn.utils.weight_norm(start)
+ self.start = start
+ # Initializing last layer to 0 makes the affine coupling layers
+ # do nothing at first. This helps with training stability
+ end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
+ end.weight.data.zero_()
+ end.bias.data.zero_()
+ self.end = end
+ self.wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels, p_dropout)
+ if wn is not None:
+ self.wn.in_layers = wn.in_layers
+ self.wn.res_skip_layers = wn.res_skip_layers
+
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
+ if x_mask is None:
+ x_mask = 1
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
+
+ x = self.start(x_0) * x_mask
+ x = self.wn(x, x_mask, g)
+ out = self.end(x)
+
+ z_0 = x_0
+ m = out[:, :self.in_channels // 2, :]
+ logs = out[:, self.in_channels // 2:, :]
+ if self.sigmoid_scale:
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
+ if reverse:
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
+ logdet = torch.sum(-logs * x_mask, [1, 2])
+ else:
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
+ logdet = torch.sum(logs * x_mask, [1, 2])
+ z = torch.cat([z_0, z_1], 1)
+ return z, logdet
+
+ def store_inverse(self):
+ self.wn.remove_weight_norm()
+
+
+class Glow(nn.Module):
+ def __init__(self,
+ in_channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_blocks,
+ n_layers,
+ p_dropout=0.,
+ n_split=4,
+ n_sqz=2,
+ sigmoid_scale=False,
+ gin_channels=0,
+ inv_conv_type='near',
+ share_cond_layers=False,
+ share_wn_layers=0,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_blocks = n_blocks
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ self.n_split = n_split
+ self.n_sqz = n_sqz
+ self.sigmoid_scale = sigmoid_scale
+ self.gin_channels = gin_channels
+ self.share_cond_layers = share_cond_layers
+ if gin_channels != 0 and share_cond_layers:
+ cond_layer = torch.nn.Conv1d(gin_channels * n_sqz, 2 * hidden_channels * n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+ wn = None
+ self.flows = nn.ModuleList()
+ for b in range(n_blocks):
+ self.flows.append(ActNorm(channels=in_channels * n_sqz))
+ if inv_conv_type == 'near':
+ self.flows.append(InvConvNear(channels=in_channels * n_sqz, n_split=n_split, n_sqz=n_sqz))
+ if inv_conv_type == 'invconv':
+ self.flows.append(InvConv(channels=in_channels * n_sqz))
+ if share_wn_layers > 0:
+ if b % share_wn_layers == 0:
+ wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels * n_sqz,
+ p_dropout, share_cond_layers)
+ self.flows.append(
+ CouplingBlock(
+ in_channels * n_sqz,
+ hidden_channels,
+ kernel_size=kernel_size,
+ dilation_rate=dilation_rate,
+ n_layers=n_layers,
+ gin_channels=gin_channels * n_sqz,
+ p_dropout=p_dropout,
+ sigmoid_scale=sigmoid_scale,
+ wn=wn
+ ))
+
+ def forward(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False):
+ logdet_tot = 0
+ if not reverse:
+ flows = self.flows
+ else:
+ flows = reversed(self.flows)
+ if return_hiddens:
+ hs = []
+ if self.n_sqz > 1:
+ x, x_mask_ = utils.squeeze(x, x_mask, self.n_sqz)
+ if g is not None:
+ g, _ = utils.squeeze(g, x_mask, self.n_sqz)
+ x_mask = x_mask_
+ if self.share_cond_layers and g is not None:
+ g = self.cond_layer(g)
+ for f in flows:
+ x, logdet = f(x, x_mask, g=g, reverse=reverse)
+ if return_hiddens:
+ hs.append(x)
+ logdet_tot += logdet
+ if self.n_sqz > 1:
+ x, x_mask = utils.unsqueeze(x, x_mask, self.n_sqz)
+ if return_hiddens:
+ return x, logdet_tot, hs
+ return x, logdet_tot
+
+ def store_inverse(self):
+ def remove_weight_norm(m):
+ try:
+ nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(remove_weight_norm)
+ for f in self.flows:
+ f.store_inverse()
diff --git a/modules/commons/normalizing_flow/res_flow.py b/modules/commons/normalizing_flow/res_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..428fb7da9e3becb0d11cdf239fff410c86028d95
--- /dev/null
+++ b/modules/commons/normalizing_flow/res_flow.py
@@ -0,0 +1,61 @@
+import torch
+from torch import nn
+from modules.commons.conv import ConditionalConvBlocks
+from modules.commons.wavenet import WN
+
+
+class FlipLayer(nn.Module):
+ def forward(self, x, *args, **kwargs):
+ x = torch.flip(x, [1])
+ return x
+
+
+class CouplingLayer(nn.Module):
+ def __init__(self, c_in, hidden_size, kernel_size, n_layers, p_dropout=0, c_in_g=0, nn_type='wn'):
+ super().__init__()
+ self.channels = c_in
+ self.hidden_size = hidden_size
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.c_half = c_in // 2
+
+ self.pre = nn.Conv1d(self.c_half, hidden_size, 1)
+ if nn_type == 'wn':
+ self.enc = WN(hidden_size, kernel_size, 1, n_layers, p_dropout=p_dropout,
+ c_cond=c_in_g)
+ elif nn_type == 'conv':
+ self.enc = ConditionalConvBlocks(
+ hidden_size, c_in_g, hidden_size, None, kernel_size,
+ layers_in_block=1, is_BTC=False, num_layers=n_layers)
+ self.post = nn.Conv1d(hidden_size, self.c_half, 1)
+
+ def forward(self, x, nonpadding, cond=None, reverse=False):
+ x0, x1 = x[:, :self.c_half], x[:, self.c_half:]
+ x_ = self.pre(x0) * nonpadding
+ x_ = self.enc(x_, nonpadding=nonpadding, cond=cond)
+ m = self.post(x_)
+ x1 = m + x1 if not reverse else x1 - m
+ x = torch.cat([x0, x1], 1)
+ return x * nonpadding
+
+
+class ResFlow(nn.Module):
+ def __init__(self,
+ c_in,
+ hidden_size,
+ kernel_size,
+ n_flow_layers,
+ n_flow_steps=4,
+ c_cond=0,
+ nn_type='wn'):
+ super().__init__()
+ self.flows = nn.ModuleList()
+ for i in range(n_flow_steps):
+ self.flows.append(
+ CouplingLayer(c_in, hidden_size, kernel_size, n_flow_layers, c_in_g=c_cond, nn_type=nn_type))
+ self.flows.append(FlipLayer())
+
+ def forward(self, x, nonpadding, cond=None, reverse=False):
+ for flow in (self.flows if not reverse else reversed(self.flows)):
+ x = flow(x, nonpadding, cond=cond, reverse=reverse)
+ return x
diff --git a/modules/commons/normalizing_flow/utils.py b/modules/commons/normalizing_flow/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb56ec514bff822ba1a19a6474207ed82492410
--- /dev/null
+++ b/modules/commons/normalizing_flow/utils.py
@@ -0,0 +1,29 @@
+import torch
+
+
+def squeeze(x, x_mask=None, n_sqz=2):
+ b, c, t = x.size()
+
+ t = (t // n_sqz) * n_sqz
+ x = x[:, :, :t]
+ x_sqz = x.view(b, c, t // n_sqz, n_sqz)
+ x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
+
+ if x_mask is not None:
+ x_mask = x_mask[:, :, n_sqz - 1::n_sqz]
+ else:
+ x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
+ return x_sqz * x_mask, x_mask
+
+
+def unsqueeze(x, x_mask=None, n_sqz=2):
+ b, c, t = x.size()
+
+ x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
+ x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
+
+ if x_mask is not None:
+ x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
+ else:
+ x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
+ return x_unsqz * x_mask, x_mask
diff --git a/modules/commons/rel_transformer.py b/modules/commons/rel_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd41b301a98609391d1a18b118d1f1b3e538af1d
--- /dev/null
+++ b/modules/commons/rel_transformer.py
@@ -0,0 +1,389 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from modules.commons.layers import Embedding
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def shift_1d(x):
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+ return x
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+class Encoder(nn.Module):
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
+ window_size=None, block_length=None, pre_ln=False, **kwargs):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.block_length = block_length
+ self.pre_ln = pre_ln
+
+ self.drop = nn.Dropout(p_dropout)
+ self.attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.attn_layers.append(
+ MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
+ p_dropout=p_dropout, block_length=block_length))
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+ if pre_ln:
+ self.last_ln = LayerNorm(hidden_channels)
+
+ def forward(self, x, x_mask, attn_mask=1):
+ if isinstance(attn_mask, torch.Tensor):
+ attn_mask = attn_mask[:, None]
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) * attn_mask
+ for i in range(self.n_layers):
+ x = x * x_mask
+ x_ = x
+ if self.pre_ln:
+ x = self.norm_layers_1[i](x)
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = x_ + y
+ if not self.pre_ln:
+ x = self.norm_layers_1[i](x)
+
+ x_ = x
+ if self.pre_ln:
+ x = self.norm_layers_2[i](x)
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = x_ + y
+ if not self.pre_ln:
+ x = self.norm_layers_2[i](x)
+ if self.pre_ln:
+ x = self.last_ln(x)
+ x = x * x_mask
+ return x
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
+ block_length=None, proximal_bias=False, proximal_init=False):
+ super().__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.window_size = window_size
+ self.heads_share = heads_share
+ self.block_length = block_length
+ self.proximal_bias = proximal_bias
+ self.p_dropout = p_dropout
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = nn.Conv1d(channels, channels, 1)
+ self.conv_k = nn.Conv1d(channels, channels, 1)
+ self.conv_v = nn.Conv1d(channels, channels, 1)
+ if window_size is not None:
+ n_heads_rel = 1 if heads_share else n_heads
+ rel_stddev = self.k_channels ** -0.5
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ nn.init.xavier_uniform_(self.conv_q.weight)
+ nn.init.xavier_uniform_(self.conv_k.weight)
+ if proximal_init:
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
+ nn.init.xavier_uniform_(self.conv_v.weight)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
+ if self.window_size is not None:
+ assert t_s == t_t, "Relative attention is only available for self-attention."
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+ rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
+ rel_logits = self._relative_position_to_absolute_position(rel_logits)
+ scores_local = rel_logits / math.sqrt(self.k_channels)
+ scores = scores + scores_local
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ if self.block_length is not None:
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
+ scores = scores * block_mask + -1e4 * (1 - block_mask)
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ if self.window_size is not None:
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
+ return output, p_attn
+
+ def _matmul_with_relative_values(self, x, y):
+ """
+ x: [b, h, l, m]
+ y: [h or 1, m, d]
+ ret: [b, h, l, d]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0))
+ return ret
+
+ def _matmul_with_relative_keys(self, x, y):
+ """
+ x: [b, h, l, d]
+ y: [h or 1, m, d]
+ ret: [b, h, l, m]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+ return ret
+
+ def _get_relative_embeddings(self, relative_embeddings, length):
+ max_relative_position = 2 * self.window_size + 1
+ # Pad first before slice to avoid using cond ops.
+ pad_length = max(length - (self.window_size + 1), 0)
+ slice_start_position = max((self.window_size + 1) - length, 0)
+ slice_end_position = slice_start_position + 2 * length - 1
+ if pad_length > 0:
+ padded_relative_embeddings = F.pad(
+ relative_embeddings,
+ convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
+ else:
+ padded_relative_embeddings = relative_embeddings
+ used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
+ return used_relative_embeddings
+
+ def _relative_position_to_absolute_position(self, x):
+ """
+ x: [b, h, l, 2*l-1]
+ ret: [b, h, l, l]
+ """
+ batch, heads, length, _ = x.size()
+ # Concat columns of pad to shift from relative to absolute indexing.
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
+ x_flat = x.view([batch, heads, length * 2 * length])
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+ # Reshape and slice out the padded elements.
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
+ return x_final
+
+ def _absolute_position_to_relative_position(self, x):
+ """
+ x: [b, h, l, l]
+ ret: [b, h, l, 2*l-1]
+ """
+ batch, heads, length, _ = x.size()
+ # padd along column
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+ x_flat = x.view([batch, heads, -1])
+ # add 0's in the beginning that will skew the elements after reshape
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+ return x_final
+
+ def _attention_bias_proximal(self, length):
+ """Bias for self-attention to encourage attention to close positions.
+ Args:
+ length: an integer scalar.
+ Returns:
+ a Tensor with shape [1, 1, length, length]
+ """
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.activation = activation
+
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(x * x_mask)
+ if self.activation == "gelu":
+ x = x * torch.sigmoid(1.702 * x)
+ else:
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(x * x_mask)
+ return x * x_mask
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-4):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ n_dims = len(x.shape)
+ mean = torch.mean(x, 1, keepdim=True)
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+ shape = [1, -1] + [1] * (n_dims - 2)
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+ return x
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ assert n_layers > 1, "Number of layers should be larger than 0."
+
+ self.conv_layers = nn.ModuleList()
+ self.norm_layers = nn.ModuleList()
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = nn.Sequential(
+ nn.ReLU(),
+ nn.Dropout(p_dropout))
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
+
+
+class RelTransformerEncoder(nn.Module):
+ def __init__(self,
+ n_vocab,
+ out_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout=0.0,
+ window_size=4,
+ block_length=None,
+ in_channels=None,
+ prenet=True,
+ pre_ln=True,
+ ):
+
+ super().__init__()
+
+ self.n_vocab = n_vocab
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.block_length = block_length
+ self.prenet = prenet
+ if n_vocab > 0:
+ self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
+
+ if prenet:
+ if in_channels is None:
+ in_channels = hidden_channels
+ self.pre = ConvReluNorm(in_channels, in_channels, in_channels,
+ kernel_size=5, n_layers=3, p_dropout=0)
+ if in_channels is not None and in_channels != hidden_channels:
+ self.encoder_inp_proj = nn.Conv1d(in_channels, hidden_channels, 1)
+ self.encoder = Encoder(
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ window_size=window_size,
+ block_length=block_length,
+ pre_ln=pre_ln,
+ )
+
+ def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1):
+ if self.n_vocab > 0:
+ x_lengths = (x > 0).long().sum(-1)
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
+ else:
+ x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
+ x = x + other_embeds
+ x = torch.transpose(x, 1, -1) # [b, h, t]
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+ if self.prenet:
+ x = self.pre(x, x_mask)
+ self.prenet_out = x.transpose(1, 2)
+ if hasattr(self, 'encoder_inp_proj'):
+ x = self.encoder_inp_proj(x) * x_mask
+ x = self.encoder(x, x_mask, attn_mask)
+ return x.transpose(1, 2)
diff --git a/modules/commons/rnn.py b/modules/commons/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..205c2c76b8fda2de920bc59228a5eec0a20119a9
--- /dev/null
+++ b/modules/commons/rnn.py
@@ -0,0 +1,261 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class PreNet(nn.Module):
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
+ super().__init__()
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
+ self.p = dropout
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = F.dropout(x, self.p, training=self.training)
+ x = self.fc2(x)
+ x = F.relu(x)
+ x = F.dropout(x, self.p, training=self.training)
+ return x
+
+
+class HighwayNetwork(nn.Module):
+ def __init__(self, size):
+ super().__init__()
+ self.W1 = nn.Linear(size, size)
+ self.W2 = nn.Linear(size, size)
+ self.W1.bias.data.fill_(0.)
+
+ def forward(self, x):
+ x1 = self.W1(x)
+ x2 = self.W2(x)
+ g = torch.sigmoid(x2)
+ y = g * F.relu(x1) + (1. - g) * x
+ return y
+
+
+class BatchNormConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
+ super().__init__()
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
+ self.bnorm = nn.BatchNorm1d(out_channels)
+ self.relu = relu
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = F.relu(x) if self.relu is True else x
+ return self.bnorm(x)
+
+
+class ConvNorm(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
+ super(ConvNorm, self).__init__()
+ if padding is None:
+ assert (kernel_size % 2 == 1)
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
+ kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation,
+ bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+ return conv_signal
+
+
+class CBHG(nn.Module):
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
+ super().__init__()
+
+ # List of all rnns to call `flatten_parameters()` on
+ self._to_flatten = []
+
+ self.bank_kernels = [i for i in range(1, K + 1)]
+ self.conv1d_bank = nn.ModuleList()
+ for k in self.bank_kernels:
+ conv = BatchNormConv(in_channels, channels, k)
+ self.conv1d_bank.append(conv)
+
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
+
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
+
+ # Fix the highway input if necessary
+ if proj_channels[-1] != channels:
+ self.highway_mismatch = True
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
+ else:
+ self.highway_mismatch = False
+
+ self.highways = nn.ModuleList()
+ for i in range(num_highways):
+ hn = HighwayNetwork(channels)
+ self.highways.append(hn)
+
+ self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
+ self._to_flatten.append(self.rnn)
+
+ # Avoid fragmentation of RNN parameters and associated warning
+ self._flatten_parameters()
+
+ def forward(self, x):
+ # Although we `_flatten_parameters()` on init, when using DataParallel
+ # the model gets replicated, making it no longer guaranteed that the
+ # weights are contiguous in GPU memory. Hence, we must call it again
+ self._flatten_parameters()
+
+ # Save these for later
+ residual = x
+ seq_len = x.size(-1)
+ conv_bank = []
+
+ # Convolution Bank
+ for conv in self.conv1d_bank:
+ c = conv(x) # Convolution
+ conv_bank.append(c[:, :, :seq_len])
+
+ # Stack along the channel axis
+ conv_bank = torch.cat(conv_bank, dim=1)
+
+ # dump the last padding to fit residual
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
+
+ # Conv1d projections
+ x = self.conv_project1(x)
+ x = self.conv_project2(x)
+
+ # Residual Connect
+ x = x + residual
+
+ # Through the highways
+ x = x.transpose(1, 2)
+ if self.highway_mismatch is True:
+ x = self.pre_highway(x)
+ for h in self.highways:
+ x = h(x)
+
+ # And then the RNN
+ x, _ = self.rnn(x)
+ return x
+
+ def _flatten_parameters(self):
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
+ to improve efficiency and avoid PyTorch yelling at us."""
+ [m.flatten_parameters() for m in self._to_flatten]
+
+
+class TacotronEncoder(nn.Module):
+ def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout):
+ super().__init__()
+ self.embedding = nn.Embedding(num_chars, embed_dims)
+ self.pre_net = PreNet(embed_dims, embed_dims, embed_dims, dropout=dropout)
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
+ proj_channels=[cbhg_channels, cbhg_channels],
+ num_highways=num_highways)
+ self.proj_out = nn.Linear(cbhg_channels * 2, cbhg_channels)
+
+ def forward(self, x):
+ x = self.embedding(x)
+ x = self.pre_net(x)
+ x.transpose_(1, 2)
+ x = self.cbhg(x)
+ x = self.proj_out(x)
+ return x
+
+
+class RNNEncoder(nn.Module):
+ def __init__(self, num_chars, embedding_dim, n_convolutions=3, kernel_size=5):
+ super(RNNEncoder, self).__init__()
+ self.embedding = nn.Embedding(num_chars, embedding_dim, padding_idx=0)
+ convolutions = []
+ for _ in range(n_convolutions):
+ conv_layer = nn.Sequential(
+ ConvNorm(embedding_dim,
+ embedding_dim,
+ kernel_size=kernel_size, stride=1,
+ padding=int((kernel_size - 1) / 2),
+ dilation=1, w_init_gain='relu'),
+ nn.BatchNorm1d(embedding_dim))
+ convolutions.append(conv_layer)
+ self.convolutions = nn.ModuleList(convolutions)
+
+ self.lstm = nn.LSTM(embedding_dim, int(embedding_dim / 2), 1,
+ batch_first=True, bidirectional=True)
+
+ def forward(self, x):
+ input_lengths = (x > 0).sum(-1)
+ input_lengths = input_lengths.cpu().numpy()
+
+ x = self.embedding(x)
+ x = x.transpose(1, 2) # [B, H, T]
+ for conv in self.convolutions:
+ x = F.dropout(F.relu(conv(x)), 0.5, self.training) + x
+ x = x.transpose(1, 2) # [B, T, H]
+
+ # pytorch tensor are not reversible, hence the conversion
+ x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
+
+ self.lstm.flatten_parameters()
+ outputs, _ = self.lstm(x)
+ outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
+
+ return outputs
+
+
+class DecoderRNN(torch.nn.Module):
+ def __init__(self, hidden_size, decoder_rnn_dim, dropout):
+ super(DecoderRNN, self).__init__()
+ self.in_conv1d = nn.Sequential(
+ torch.nn.Conv1d(
+ in_channels=hidden_size,
+ out_channels=hidden_size,
+ kernel_size=9, padding=4,
+ ),
+ torch.nn.ReLU(),
+ torch.nn.Conv1d(
+ in_channels=hidden_size,
+ out_channels=hidden_size,
+ kernel_size=9, padding=4,
+ ),
+ )
+ self.ln = nn.LayerNorm(hidden_size)
+ if decoder_rnn_dim == 0:
+ decoder_rnn_dim = hidden_size * 2
+ self.rnn = torch.nn.LSTM(
+ input_size=hidden_size,
+ hidden_size=decoder_rnn_dim,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=True,
+ dropout=dropout
+ )
+ self.rnn.flatten_parameters()
+ self.conv1d = torch.nn.Conv1d(
+ in_channels=decoder_rnn_dim * 2,
+ out_channels=hidden_size,
+ kernel_size=3,
+ padding=1,
+ )
+
+ def forward(self, x):
+ input_masks = x.abs().sum(-1).ne(0).data[:, :, None]
+ input_lengths = input_masks.sum([-1, -2])
+ input_lengths = input_lengths.cpu().numpy()
+
+ x = self.in_conv1d(x.transpose(1, 2)).transpose(1, 2)
+ x = self.ln(x)
+ x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
+ self.rnn.flatten_parameters()
+ x, _ = self.rnn(x) # [B, T, C]
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
+ x = x * input_masks
+ pre_mel = self.conv1d(x.transpose(1, 2)).transpose(1, 2) # [B, T, C]
+ pre_mel = pre_mel * input_masks
+ return pre_mel
diff --git a/modules/commons/rot_transformer.py b/modules/commons/rot_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d17c488042b54a70f0b897f4efc488dfbce3b3b3
--- /dev/null
+++ b/modules/commons/rot_transformer.py
@@ -0,0 +1,635 @@
+import math
+import torch
+from typing import Optional, Tuple
+from torch import nn
+from torch.nn import Parameter, Linear
+from torch.cuda.amp import autocast
+from modules.commons.layers import LayerNorm, Embedding
+from modules.commons.transformer import TransformerFFNLayer, MultiheadAttention
+from utils.nn.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
+import torch.nn.functional as F
+
+DEFAULT_MAX_SOURCE_POSITIONS = 3000
+DEFAULT_MAX_TARGET_POSITIONS = 3000
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+
+ Padding symbols are ignored.
+ """
+
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ init_size,
+ embedding_dim,
+ padding_idx,
+ )
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+ @staticmethod
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+ """Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
+
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = input.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if self.weights is None or max_pos > self.weights.size(0):
+ # recompute/expand embeddings if needed
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ max_pos,
+ self.embedding_dim,
+ self.padding_idx,
+ )
+ self.weights = self.weights.to(self._float_tensor)
+
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+ def max_positions(self):
+ """Maximum number of supported positions."""
+ return int(1e5) # an arbitrary large number
+
+
+class RotaryEmbeddings(nn.Module):
+ cos: torch.Tensor
+ sin: torch.Tensor
+ theta: torch.Tensor
+
+ def __init__(
+ self,
+ width: int,
+ *,
+ seq_len: int = 4000,
+ base: int = 10000,
+ device: Optional[torch.device] = None,
+ ):
+ """Rotary embeddings (Su et al., 2021) layer. The rotary embedding
+ will be precomputed for up to 'seq _len' positions. The embedding
+ will be recomputed when a longer sequence is found in the input.
+
+ :param width:
+ Rotary embedding dimensionality, must be even.
+ :param seq_len:
+ Number of positons to initially precompute.
+ :param base:
+ The base used for Θ_i, determines the cycle length of the
+ embeddings.
+ :param device: Device on which the module is to be initialized.
+ """
+ super().__init__()
+
+ if width % 2:
+ raise ValueError(f"Width of rotary embeddings must be even, was: {width}")
+
+ # Ignore allocations on the meta device as we don't persist our buffer,
+ # i.e., we don't expect the backing tensor to be replaced with pretrained weights.
+ if device is not None and device.type == "meta":
+ device = None
+ # Θ_i = 10000^(-2(i-1)/d)
+ theta = torch.pow(
+ base, -torch.arange(0, width, 2, dtype=torch.float, device=device) / width
+ )
+ self.register_buffer("theta", theta, persistent=False)
+
+ self._create_rotary_embed(width=width, length=seq_len)
+
+ def _create_rotary_embed(self, *, width: int, length: int):
+ # mΘ
+ position = torch.arange(length, device=self.theta.device).unsqueeze(1)
+ m_theta = position * self.theta.unsqueeze(0)
+
+ # We apply both sin and cos twice (see Eq 15, 34), but the ordering
+ # is changed for compatibility with most common implementations.
+ m_theta = torch.cat([m_theta, m_theta], dim=-1)
+
+ re_cos = m_theta.cos().view([length, width])
+ re_sin = m_theta.sin().view([length, width])
+
+ self.register_buffer("cos", re_cos, persistent=False)
+ self.register_buffer("sin", re_sin, persistent=False)
+
+ def _rotate(self, input: torch.Tensor):
+ """Rotate the input tensor by half of its innermost width.
+
+ input (Tensor): array to rotate.
+ RETURNS (Tensor): rotated array.
+
+ Shapes:
+ input - (..., width)
+ output - (..., width)
+ """
+ half_idx = input.shape[-1] // 2
+ input_1 = -input[..., half_idx:]
+ input_2 = input[..., :half_idx]
+ return torch.cat([input_1, input_2], dim=-1)
+
+ def forward(self, input: torch.Tensor, *, positions: Optional[torch.Tensor] = None):
+ """
+ Apply rotary embeddings to an array.
+
+ :param input: Array to apply the rotary embeddings to.
+ :param positions: positions of the inputs. If no positions are
+ provided, they are assumed to be [0, seq_len).
+ :return: Array with the rotary embeddings applied.
+
+ Shapes:
+ input - (batch_size, num_heads, seq_len, width_per_head)
+ positions - (batch_size, seq_len)
+ output - (batch_size, num_heads, seq_len, width_per_head)
+ """
+ batch_size, _, seq_len, width = input.shape
+
+ if positions is None:
+ # Fastpath: positions from [0..seq_len), avoid indexing.
+ if self.cos.size(-2) < seq_len:
+ self._create_rotary_embed(width=width, length=seq_len)
+ rot_cos = self.cos[:seq_len, :].view(1, 1, seq_len, width)
+ rot_sin = self.sin[:seq_len, :].view(1, 1, seq_len, width)
+ else:
+ max_len = int(positions.max()) + 1
+ if self.cos.size(-2) < max_len:
+ self._create_rotary_embed(width=width, length=max_len)
+
+ # Flatten positions to index cos/sin arrays, then unflatten.
+ #
+ # Example shapes:
+ #
+ # positions_flat - (batch_size * seq_len)
+ # self.cos - (max_len, width)
+ # rot_cos - (batch_size, seq_len, width)
+ positions_flat = positions.view(-1)
+ rot_cos = self.cos[positions_flat].view(batch_size, 1, seq_len, width)
+ rot_sin = self.sin[positions_flat].view(batch_size, 1, seq_len, width)
+
+ # Eq 34 with ordering changed for compatibility.
+ return rot_cos * input + rot_sin * self._rotate(input)
+
+
+class RotMultiheadAttention(MultiheadAttention):
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
+ encoder_decoder_attention=False):
+ super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
+ add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
+ encoder_decoder_attention=encoder_decoder_attention)
+ self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
+
+ def forward(
+ self,
+ query, key, value,
+ spk_pos_ids_flat=None,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ reset_attn_weight=None
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ # self-attention
+ q, k, v = self.in_proj_qkv(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(key)
+ else:
+ q = self.in_proj_q(query)
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(value)
+ q = q * self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ # Apply rot embedding and store incremental_state
+ q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
+ bsz, self.num_heads, -1, self.head_dim)
+ self._set_input_buffer(incremental_state, saved_state)
+ if incremental_state is not None:
+ key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
+ else:
+ key_pos = spk_pos_ids_flat
+ k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
+
+ src_len = k.size(1)
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights + attn_mask
+
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+ if reset_attn_weight is not None:
+ if reset_attn_weight:
+ self.last_attn_probs = attn_probs.detach()
+ else:
+ assert self.last_attn_probs is not None
+ attn_probs = self.last_attn_probs
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+ else:
+ attn_weights = None
+
+ return attn, (attn_weights, attn_logits)
+
+
+class RotMultiheadAttention2(MultiheadAttention):
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
+ encoder_decoder_attention=False):
+ super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
+ add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
+ encoder_decoder_attention=encoder_decoder_attention)
+ self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
+
+ def forward(
+ self,
+ query, key, value,
+ spk_pos_ids_flat=None,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ reset_attn_weight=None
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ # self-attention
+ q, k, v = self.in_proj_qkv(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(key)
+ else:
+ q = self.in_proj_q(query)
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(value)
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ # Apply rot embedding and store incremental_state
+ q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
+ bsz, self.num_heads, -1, self.head_dim)
+ self._set_input_buffer(incremental_state, saved_state)
+ key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
+ k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
+
+ src_len = k.size(1)
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=attn_mask, dropout_p=0, is_causal=False)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn_logits = None
+ attn_weights = None
+ return attn, (attn_weights, attn_logits)
+
+
+class RotDecSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+ kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.layer_norm1 = LayerNorm(c)
+ self.self_attn = RotMultiheadAttention(
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+ )
+ self.layer_norm2 = LayerNorm(c)
+ self.ffn = TransformerFFNLayer(
+ c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+ self.post_ln = post_ln
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ self_attn_mask=None,
+ self_attn_padding_mask=None,
+ attn_out=None,
+ reset_attn_weight=None,
+ spk_pos_ids_flat=None,
+ **kwargs,
+ ):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm1(x)
+
+ x, (attn_weights, _) = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ incremental_state=incremental_state,
+ attn_mask=self_attn_mask,
+ spk_pos_ids_flat=spk_pos_ids_flat
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm1(x)
+
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm2(x)
+ x = self.ffn(x, incremental_state=incremental_state)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm2(x)
+ return x, attn_weights
+
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+ self.encoder_attn.clear_buffer(incremental_state)
+ self.ffn.clear_buffer(incremental_state)
+
+ def set_buffer(self, name, tensor, incremental_state):
+ return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class RotDecSALayer2(RotDecSALayer):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9,
+ ffn_hidden_size=1024, act='gelu', post_ln=False):
+ super().__init__(c, num_heads, dropout, attention_dropout, relu_dropout, kernel_size, ffn_hidden_size, act,
+ post_ln)
+ self.self_attn = RotMultiheadAttention2(
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+ )
+
+
+class RotTransformerDecoderLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, ffn_hidden_size=1024, post_ln=False,
+ op_version=1):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ if op_version == 1:
+ self.op = RotDecSALayer(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
+ post_ln=post_ln)
+ else:
+ self.op = RotDecSALayer2(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
+ post_ln=post_ln)
+
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+ def clear_buffer(self, *args):
+ return self.op.clear_buffer(*args)
+
+ def set_buffer(self, *args):
+ return self.op.set_buffer(*args)
diff --git a/modules/commons/taming_tfm_modules.py b/modules/commons/taming_tfm_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..79418633fbf06fac1afaa2d794a9ef2af9bdb7b3
--- /dev/null
+++ b/modules/commons/taming_tfm_modules.py
@@ -0,0 +1,366 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+class Normalize(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+ self.proj = nn.Linear(channels, channels)
+
+ def forward(self, x):
+ x = x.transpose(1, 2)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ x = self.proj(x)
+ return x.transpose(1, 2)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x):
+ if self.with_conv:
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.conv2 = torch.nn.Conv1d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, _, x_mask):
+ x = x * x_mask
+ h = x
+ h = self.norm1(h) * x_mask
+ h = nonlinearity(h) * x_mask
+ h = self.conv1(h) * x_mask
+
+ h = self.norm2(h) * x_mask
+ h = nonlinearity(h) * x_mask
+ h = self.conv2(h) * x_mask
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x) * x_mask
+ else:
+ x = self.nin_shortcut(x) * x_mask
+
+ return (x + h) * x_mask
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, x_mask):
+ h_ = x * x_mask
+ h_ = self.norm(h_) * x_mask
+ q = self.q(h_) * x_mask
+ k = self.k(h_) * x_mask
+ v = self.v(h_) * x_mask
+
+ # compute attention
+ b, c, h = q.shape
+ w = 1
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = w_ + ((1 - x_mask) * -1e8) + ((1 - x_mask) * -1e8).transpose(1, 2)
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h)
+
+ h_ = self.proj_out(h_) * x_mask
+
+ return (x + h_) * x_mask
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
+ resamp_with_conv=False, in_channels):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv1d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch))
+ block_in = block_out
+ if i_level == self.num_resolutions - 1:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv1d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, x_mask):
+ if x_mask is None:
+ x_mask = torch.ones_like(x_mask[:, :, :1])
+ x = x.permute(0, 2, 1)
+ x_mask = x_mask.permute(0, 2, 1)
+
+ temb = None
+ # downsampling
+ hs = [self.conv_in(x) * x_mask]
+ for i_level in range(self.num_resolutions):
+ x_mask_ = x_mask[:, :, ::2 ** i_level]
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb, x_mask_) * x_mask_
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h, x_mask_) * x_mask_
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]) * x_mask_[:, :, ::2])
+
+ x_mask_ = x_mask[:, :, ::2 ** (self.num_resolutions - 1)]
+ # middle
+ h = hs[-1] * x_mask_
+ h = self.mid.block_1(h, temb, x_mask_) * x_mask_
+ h = self.mid.attn_1(h, x_mask_) * x_mask_
+ h = self.mid.block_2(h, temb, x_mask_) * x_mask_
+
+ # end
+ h = self.norm_out(h) * x_mask_
+ h = nonlinearity(h) * x_mask_
+ h = self.conv_out(h) * x_mask_
+ h = h.permute(0, 2, 1)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
+ resamp_with_conv=True, in_channels, give_pre_end=False):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv1d(in_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch))
+ block_in = block_out
+ if i_level == self.num_resolutions - 1:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv1d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z, x_mask):
+ if x_mask is None:
+ x_mask = torch.ones_like(z[:, :, :1]).repeat(1, 8, 1)
+ z = z.permute(0, 2, 1)
+ x_mask = x_mask.permute(0, 2, 1)
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ i_level = self.num_resolutions - 1
+ x_mask_ = x_mask[:, :, ::2 ** i_level]
+ h = self.mid.block_1(h, temb, x_mask_)
+ h = self.mid.attn_1(h, x_mask_)
+ h = self.mid.block_2(h, temb, x_mask_)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ x_mask_ = x_mask[:, :, ::2 ** i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, x_mask_)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, x_mask_)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h) * x_mask
+ h = h.permute(0, 2, 1)
+ return h
diff --git a/modules/commons/transformer.py b/modules/commons/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..36e09edfb2a124f7cc8913254b167fefec4f5b96
--- /dev/null
+++ b/modules/commons/transformer.py
@@ -0,0 +1,752 @@
+import math
+import torch
+from torch import nn
+from torch.nn import Parameter, Linear
+from modules.commons.layers import LayerNorm, Embedding
+from utils.nn.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
+import torch.nn.functional as F
+
+DEFAULT_MAX_SOURCE_POSITIONS = 3000
+DEFAULT_MAX_TARGET_POSITIONS = 3000
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+
+ Padding symbols are ignored.
+ """
+
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ init_size,
+ embedding_dim,
+ padding_idx,
+ )
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+ @staticmethod
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+ """Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
+
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = input.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if self.weights is None or max_pos > self.weights.size(0):
+ # recompute/expand embeddings if needed
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ max_pos,
+ self.embedding_dim,
+ self.padding_idx,
+ )
+ self.weights = self.weights.to(self._float_tensor)
+
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+ def max_positions(self):
+ """Maximum number of supported positions."""
+ return int(1e5) # an arbitrary large number
+
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.dropout = dropout
+ self.act = act
+ if padding == 'SAME':
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
+ elif padding == 'LEFT':
+ self.ffn_1 = nn.Sequential(
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
+ )
+ self.ffn_2 = Linear(filter_size, hidden_size)
+
+ def forward(self, x, incremental_state=None):
+ # x: T x B x C
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ prev_input = saved_state['prev_input']
+ x = torch.cat((prev_input, x), dim=0)
+ x = x[-self.kernel_size:]
+ saved_state['prev_input'] = x
+ self._set_input_buffer(incremental_state, saved_state)
+
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
+ x = x * self.kernel_size ** -0.5
+
+ if incremental_state is not None:
+ x = x[-1:]
+ if self.act == 'gelu':
+ x = F.gelu(x)
+ if self.act == 'relu':
+ x = F.relu(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ buffer,
+ )
+
+ def clear_buffer(self, incremental_state):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ del saved_state['prev_input']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
+ encoder_decoder_attention=False):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim ** -0.5
+
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+ 'value to be of the same size'
+
+ if self.qkv_same_dim:
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+ else:
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+ if bias:
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self.reset_parameters()
+
+ self.enable_torch_version = False
+ self.last_attn_probs = None
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ nn.init.xavier_uniform_(self.in_proj_weight)
+ else:
+ nn.init.xavier_uniform_(self.k_proj_weight)
+ nn.init.xavier_uniform_(self.v_proj_weight)
+ nn.init.xavier_uniform_(self.q_proj_weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.in_proj_bias is not None:
+ nn.init.constant_(self.in_proj_bias, 0.)
+ nn.init.constant_(self.out_proj.bias, 0.)
+ if self.bias_k is not None:
+ nn.init.xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ nn.init.xavier_normal_(self.bias_v)
+
+ def forward(
+ self,
+ query, key, value,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ reset_attn_weight=None
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
+ if self.qkv_same_dim:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask)
+ else:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ torch.empty([0]),
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask, use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight,
+ k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight)
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ # self-attention
+ q, k, v = self.in_proj_qkv(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(key)
+
+ else:
+ q = self.in_proj_q(query)
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(value)
+ q = q * self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+ prev_key_padding_mask = saved_state['prev_key_padding_mask']
+ if static_kv:
+ key_padding_mask = prev_key_padding_mask
+ else:
+ key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
+
+ saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state['prev_key_padding_mask'] = key_padding_mask
+
+ self._set_input_buffer(incremental_state, saved_state)
+
+ src_len = k.size(1)
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights + attn_mask
+
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+ if reset_attn_weight is not None:
+ if reset_attn_weight:
+ self.last_attn_probs = attn_probs.detach()
+ else:
+ assert self.last_attn_probs is not None
+ attn_probs = self.last_attn_probs
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+ else:
+ attn_weights = None
+
+ return attn, (attn_weights, attn_logits)
+
+ def in_proj_qkv(self, query):
+ return self._in_proj(query).chunk(3, dim=-1)
+
+ def in_proj_q(self, query):
+ if self.qkv_same_dim:
+ return self._in_proj(query, end=self.embed_dim)
+ else:
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[:self.embed_dim]
+ return F.linear(query, self.q_proj_weight, bias)
+
+ def in_proj_k(self, key):
+ if self.qkv_same_dim:
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+ else:
+ weight = self.k_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[self.embed_dim:2 * self.embed_dim]
+ return F.linear(key, weight, bias)
+
+ def in_proj_v(self, value):
+ if self.qkv_same_dim:
+ return self._in_proj(value, start=2 * self.embed_dim)
+ else:
+ weight = self.v_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[2 * self.embed_dim:]
+ return F.linear(value, weight, bias)
+
+ def _in_proj(self, input, start=0, end=None):
+ weight = self.in_proj_weight
+ bias = self.in_proj_bias
+ weight = weight[start:end, :]
+ if bias is not None:
+ bias = bias[start:end]
+ return F.linear(input, weight, bias)
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ buffer,
+ )
+
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+ return attn_weights
+
+ def clear_buffer(self, incremental_state=None):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ del saved_state['prev_key']
+ if 'prev_value' in saved_state:
+ del saved_state['prev_value']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class EncSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
+ relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu',
+ ffn_hidden_size=1024):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.num_heads = num_heads
+ if num_heads > 0:
+ self.layer_norm1 = LayerNorm(c)
+ self.self_attn = MultiheadAttention(
+ self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
+ self.layer_norm2 = LayerNorm(c)
+ self.ffn = TransformerFFNLayer(
+ c, ffn_hidden_size, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
+
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ if self.num_heads > 0:
+ residual = x
+ x = self.layer_norm1(x)
+ x, _, = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=encoder_padding_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+
+ residual = x
+ x = self.layer_norm2(x)
+ x = self.ffn(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+ return x
+
+
+class DecSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+ kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.layer_norm1 = LayerNorm(c)
+ self.self_attn = MultiheadAttention(
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+ )
+ self.layer_norm2 = LayerNorm(c)
+ self.encoder_attn = MultiheadAttention(
+ c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
+ )
+ self.layer_norm3 = LayerNorm(c)
+ self.ffn = TransformerFFNLayer(
+ c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+ self.post_ln = post_ln
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ self_attn_mask=None,
+ self_attn_padding_mask=None,
+ attn_out=None,
+ reset_attn_weight=None,
+ **kwargs,
+ ):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ self.layer_norm3.training = layer_norm_training
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm1(x)
+ x, _ = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ incremental_state=incremental_state,
+ attn_mask=self_attn_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm1(x)
+
+ attn_logits = None
+ if encoder_out is not None or attn_out is not None:
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm2(x)
+ if encoder_out is not None:
+ x, attn = self.encoder_attn(
+ query=x,
+ key=encoder_out,
+ value=encoder_out,
+ key_padding_mask=encoder_padding_mask,
+ incremental_state=incremental_state,
+ static_kv=True,
+ enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
+ 'enc_dec_attn_constraint_mask'),
+ reset_attn_weight=reset_attn_weight
+ )
+ attn_logits = attn[1]
+ elif attn_out is not None:
+ x = self.encoder_attn.in_proj_v(attn_out)
+ if encoder_out is not None or attn_out is not None:
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm2(x)
+
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm3(x)
+ x = self.ffn(x, incremental_state=incremental_state)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm3(x)
+ return x, attn_logits
+
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+ self.encoder_attn.clear_buffer(incremental_state)
+ self.ffn.clear_buffer(incremental_state)
+
+ def set_buffer(self, name, tensor, incremental_state):
+ return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.op = EncSALayer(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size)
+
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+
+class TransformerDecoderLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024, post_ln=False):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.op = DecSALayer(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
+ post_ln=post_ln)
+
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+ def clear_buffer(self, *args):
+ return self.op.clear_buffer(*args)
+
+ def set_buffer(self, *args):
+ return self.op.set_buffer(*args)
+
+
+class FFTBlocks(nn.Module):
+ def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
+ num_heads=2, use_pos_embed=True, use_last_norm=True,
+ use_pos_embed_alpha=True, ffn_hidden_size=1024):
+ super().__init__()
+ self.num_layers = num_layers
+ embed_dim = self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.use_pos_embed = use_pos_embed
+ self.use_last_norm = use_last_norm
+ if use_pos_embed:
+ self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
+ self.padding_idx = 0
+ self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
+ self.embed_positions = SinusoidalPositionalEmbedding(
+ embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+ )
+
+ self.layers = nn.ModuleList([])
+ self.layers.extend([
+ TransformerEncoderLayer(self.hidden_size, self.dropout,
+ kernel_size=ffn_kernel_size, num_heads=num_heads,
+ ffn_hidden_size=ffn_hidden_size)
+ for _ in range(self.num_layers)
+ ])
+ if self.use_last_norm:
+ self.layer_norm = nn.LayerNorm(embed_dim)
+ else:
+ self.layer_norm = None
+
+ def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
+ """
+ :param x: [B, T, C]
+ :param padding_mask: [B, T]
+ :return: [B, T, C] or [L, B, T, C]
+ """
+ padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
+ if self.use_pos_embed:
+ positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
+ x = x + positions
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1) * nonpadding_mask_TB
+ hiddens = []
+ for layer in self.layers:
+ x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
+ hiddens.append(x)
+ if self.use_last_norm:
+ x = self.layer_norm(x) * nonpadding_mask_TB
+ if return_hiddens:
+ x = torch.stack(hiddens, 0) # [L, T, B, C]
+ x = x.transpose(1, 2) # [L, B, T, C]
+ else:
+ x = x.transpose(0, 1) # [B, T, C]
+ return x
+
+
+class FastSpeechEncoder(FFTBlocks):
+ def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9,
+ dropout=0.0, num_heads=2, ffn_hidden_size=1024):
+ super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
+ use_pos_embed=False, dropout=dropout, ffn_hidden_size=ffn_hidden_size)
+ self.embed_tokens = Embedding(dict_size, hidden_size, 0)
+ self.embed_scale = math.sqrt(hidden_size)
+ self.padding_idx = 0
+ self.embed_positions = SinusoidalPositionalEmbedding(
+ hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+ )
+
+ def forward(self, txt_tokens, attn_mask=None, other_embeds=0):
+ """
+
+ :param txt_tokens: [B, T]
+ :return: {
+ 'encoder_out': [B x T x C]
+ }
+ """
+ encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
+ x = self.forward_embedding(txt_tokens) + other_embeds # [B, T, H]
+ if self.num_layers > 0:
+ x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
+ return x
+
+ def forward_embedding(self, txt_tokens):
+ # embed tokens and positions
+ x = self.embed_scale * self.embed_tokens(txt_tokens)
+ if self.use_pos_embed:
+ positions = self.embed_positions(txt_tokens)
+ x = x + positions
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ return x
diff --git a/modules/commons/unet1d.py b/modules/commons/unet1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed8a9bc82c22058bcc6d9c2ea59868b35c7fc2d5
--- /dev/null
+++ b/modules/commons/unet1d.py
@@ -0,0 +1,202 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+
+class UNet1d(nn.Module):
+
+ def __init__(self, in_channels=3, out_channels=1, init_features=128, multi=None):
+ super(UNet1d, self).__init__()
+ if multi is None:
+ multi = [1, 2, 2, 4]
+ features = init_features
+ self.encoder1 = UNet1d._block(in_channels, features * multi[0], name="enc1")
+ self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
+ self.encoder2 = UNet1d._block(features * multi[0], features * multi[1], name="enc2")
+ self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
+ self.encoder3 = UNet1d._block(features * multi[1], features * multi[2], name="enc3")
+ self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
+ self.encoder4 = UNet1d._block(features * multi[2], features * multi[3], name="enc4")
+ self.pool4 = nn.MaxPool1d(kernel_size=2, stride=2)
+
+ self.bottleneck = UNet1d._block(features * multi[3], features * multi[3], name="bottleneck")
+
+ self.upconv4 = nn.ConvTranspose1d(
+ features * multi[3], features * multi[3], kernel_size=2, stride=2
+ )
+ self.decoder4 = UNet1d._block((features * multi[3]) * 2, features * multi[3], name="dec4")
+ self.upconv3 = nn.ConvTranspose1d(
+ features * multi[3], features * multi[2], kernel_size=2, stride=2
+ )
+ self.decoder3 = UNet1d._block((features * multi[2]) * 2, features * multi[2], name="dec3")
+ self.upconv2 = nn.ConvTranspose1d(
+ features * multi[2], features * multi[1], kernel_size=2, stride=2
+ )
+ self.decoder2 = UNet1d._block((features * multi[1]) * 2, features * multi[1], name="dec2")
+ self.upconv1 = nn.ConvTranspose1d(
+ features * multi[1], features * multi[0], kernel_size=2, stride=2
+ )
+ self.decoder1 = UNet1d._block(features * multi[0] * 2, features * multi[0], name="dec1")
+
+ self.conv = nn.Conv1d(
+ in_channels=features * multi[0], out_channels=out_channels, kernel_size=1
+ )
+
+ def forward(self, x, nonpadding=None):
+ if nonpadding is None:
+ nonpadding = torch.ones_like(x)[:, :, :1]
+ enc1 = self.encoder1(x.transpose(1, 2)) * nonpadding.transpose(1, 2)
+ enc2 = self.encoder2(self.pool1(enc1))
+ enc3 = self.encoder3(self.pool2(enc2))
+ enc4 = self.encoder4(self.pool3(enc3))
+
+ bottleneck = self.bottleneck(self.pool4(enc4))
+
+ dec4 = self.upconv4(bottleneck)
+ dec4 = torch.cat((dec4, enc4), dim=1)
+ dec4 = self.decoder4(dec4)
+ dec3 = self.upconv3(dec4)
+ dec3 = torch.cat((dec3, enc3), dim=1)
+ dec3 = self.decoder3(dec3)
+ dec2 = self.upconv2(dec3)
+ dec2 = torch.cat((dec2, enc2), dim=1)
+ dec2 = self.decoder2(dec2)
+ dec1 = self.upconv1(dec2)
+ dec1 = torch.cat((dec1, enc1), dim=1)
+ dec1 = self.decoder1(dec1)
+ return self.conv(dec1).transpose(1, 2) * nonpadding
+
+ @staticmethod
+ def _block(in_channels, features, name):
+ return nn.Sequential(
+ OrderedDict(
+ [
+ (
+ name + "conv1",
+ nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=features,
+ kernel_size=5,
+ padding=2,
+ bias=False,
+ ),
+ ),
+ (name + "norm1", nn.GroupNorm(4, features)),
+ (name + "tanh1", nn.Tanh()),
+ (
+ name + "conv2",
+ nn.Conv1d(
+ in_channels=features,
+ out_channels=features,
+ kernel_size=5,
+ padding=2,
+ bias=False,
+ ),
+ ),
+ (name + "norm2", nn.GroupNorm(4, features)),
+ (name + "tanh2", nn.Tanh()),
+ ]
+ )
+ )
+
+
+class UNet2d(nn.Module):
+ def __init__(self, in_channels=3, out_channels=1, init_features=32, multi=None):
+ super(UNet2d, self).__init__()
+
+ features = init_features
+ self.encoder1 = UNet2d._block(in_channels, features * multi[0], name="enc1")
+ self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.encoder2 = UNet2d._block(features * multi[0], features * multi[1], name="enc2")
+ self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.encoder3 = UNet2d._block(features * multi[1], features * multi[2], name="enc3")
+ self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.encoder4 = UNet2d._block(features * multi[2], features * multi[3], name="enc4")
+ self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+ self.bottleneck = UNet2d._block(features * multi[3], features * multi[3], name="bottleneck")
+
+ self.upconv4 = nn.ConvTranspose2d(
+ features * multi[3], features * multi[3], kernel_size=2, stride=2
+ )
+ self.decoder4 = UNet2d._block((features * multi[3]) * 2, features * multi[3], name="dec4")
+ self.upconv3 = nn.ConvTranspose2d(
+ features * multi[3], features * multi[2], kernel_size=2, stride=2
+ )
+ self.decoder3 = UNet2d._block((features * multi[2]) * 2, features * multi[2], name="dec3")
+ self.upconv2 = nn.ConvTranspose2d(
+ features * multi[2], features * multi[1], kernel_size=2, stride=2
+ )
+ self.decoder2 = UNet2d._block((features * multi[1]) * 2, features * multi[1], name="dec2")
+ self.upconv1 = nn.ConvTranspose2d(
+ features * multi[1], features * multi[0], kernel_size=2, stride=2
+ )
+ self.decoder1 = UNet2d._block(features * multi[0] * 2, features * multi[0], name="dec1")
+
+ self.conv = nn.Conv2d(
+ in_channels=features * multi[0], out_channels=out_channels, kernel_size=1
+ )
+
+ def forward(self, x):
+ enc1 = self.encoder1(x)
+ enc2 = self.encoder2(self.pool1(enc1))
+ enc3 = self.encoder3(self.pool2(enc2))
+ enc4 = self.encoder4(self.pool3(enc3))
+
+ bottleneck = self.bottleneck(self.pool4(enc4))
+
+ dec4 = self.upconv4(bottleneck)
+ dec4 = torch.cat((dec4, enc4), dim=1)
+ dec4 = self.decoder4(dec4)
+ dec3 = self.upconv3(dec4)
+ dec3 = torch.cat((dec3, enc3), dim=1)
+ dec3 = self.decoder3(dec3)
+ dec2 = self.upconv2(dec3)
+ dec2 = torch.cat((dec2, enc2), dim=1)
+ dec2 = self.decoder2(dec2)
+ dec1 = self.upconv1(dec2)
+ dec1 = torch.cat((dec1, enc1), dim=1)
+ dec1 = self.decoder1(dec1)
+ x = self.conv(dec1)
+ return x
+
+ @staticmethod
+ def _block(in_channels, features, name):
+ return nn.Sequential(
+ OrderedDict(
+ [
+ (
+ name + "conv1",
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=features,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ ),
+ (name + "norm1", nn.GroupNorm(4, features)),
+ (name + "tanh1", nn.Tanh()),
+ (
+ name + "conv2",
+ nn.Conv2d(
+ in_channels=features,
+ out_channels=features,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ ),
+ (name + "norm2", nn.GroupNorm(4, features)),
+ (name + "tanh2", nn.Tanh()),
+ (name + "conv3", nn.Conv2d(
+ in_channels=features,
+ out_channels=features,
+ kernel_size=1,
+ padding=0,
+ bias=True,
+ )),
+ ]
+ )
+ )
diff --git a/modules/commons/vqvae.py b/modules/commons/vqvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc259ad1ecbc4aca7397f25476f407c43a032a0
--- /dev/null
+++ b/modules/commons/vqvae.py
@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+from scipy.cluster.vq import kmeans2
+from torch.nn import functional as F
+
+
+class VQEmbeddingEMA(nn.Module):
+ def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5,
+ print_vq_prob=False):
+ super(VQEmbeddingEMA, self).__init__()
+ self.commitment_cost = commitment_cost
+ self.n_embeddings = n_embeddings
+ self.decay = decay
+ self.epsilon = epsilon
+ self.print_vq_prob = print_vq_prob
+ self.register_buffer('data_initialized', torch.zeros(1))
+
+ init_bound = 1 / 512
+ embedding = torch.Tensor(n_embeddings, embedding_dim)
+ embedding.uniform_(-init_bound, init_bound)
+ self.register_buffer("embedding", embedding)
+ self.register_buffer("ema_count", torch.zeros(n_embeddings))
+ self.register_buffer("ema_weight", self.embedding.clone())
+
+ def encode(self, x):
+ B, T, _ = x.shape
+ M, D = self.embedding.size()
+ x_flat = x.detach().reshape(-1, D)
+
+ distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
+ torch.sum(x_flat ** 2, dim=1, keepdim=True),
+ x_flat, self.embedding.t(),
+ alpha=-2.0, beta=1.0) # [B*T_mel, N_vq]
+ indices = torch.argmin(distances.float(), dim=-1) # [B*T_mel]
+ quantized = F.embedding(indices, self.embedding)
+ quantized = quantized.view_as(x)
+ return x_flat, quantized, indices
+
+ def forward(self, x):
+ """
+
+ :param x: [B, T, D]
+ :return: [B, T, D]
+ """
+ B, T, _ = x.shape
+ M, D = self.embedding.size()
+ # if self.training and self.data_initialized.item() == 0:
+ # print('| running kmeans in VQVAE') # data driven initialization for the embeddings
+ # x_flat = x.detach().reshape(-1, D)
+ # rp = torch.randperm(x_flat.size(0))
+ # kd = kmeans2(x_flat[rp].data.cpu().numpy(), self.n_embeddings, minit='points')
+ # self.embedding.copy_(torch.from_numpy(kd[0]))
+ # x_flat, quantized, indices = self.encode(x)
+ # encodings = F.one_hot(indices, M).float()
+ # self.ema_weight.copy_(torch.matmul(encodings.t(), x_flat))
+ # self.ema_count.copy_(torch.sum(encodings, dim=0))
+
+ x_flat, quantized, indices = self.encode(x)
+ encodings = F.one_hot(indices, M).float()
+ indices = indices.reshape(B, T)
+
+ if self.training and self.data_initialized.item() != 0:
+ self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)
+
+ n = torch.sum(self.ema_count)
+ self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n
+
+ dw = torch.matmul(encodings.t(), x_flat)
+ self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw
+
+ self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)
+
+ if self.training and self.data_initialized.item() == 0:
+ self.data_initialized.fill_(1)
+
+ e_latent_loss = F.mse_loss(x, quantized.detach(), reduction='none')
+ nonpadding = (x.abs().sum(-1) > 0).float()
+ e_latent_loss = (e_latent_loss.mean(-1) * nonpadding).sum() / nonpadding.sum()
+ loss = self.commitment_cost * e_latent_loss
+
+ quantized = x + (quantized - x).detach()
+
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+ if self.print_vq_prob:
+ print("| VQ code avg_probs: ", avg_probs)
+ return quantized, loss, indices, perplexity
+
+
+class VQEmbedding(nn.Module):
+ def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, lambda_kl=1.0):
+ super(VQEmbedding, self).__init__()
+ self.commitment_cost = commitment_cost
+ self.lambda_kl = lambda_kl
+ self.n_embeddings = n_embeddings
+ embedding = torch.Tensor(n_embeddings, embedding_dim)
+ self.register_buffer("embedding", embedding)
+ self.register_buffer('data_initialized', torch.zeros(1))
+
+ def encode(self, x):
+ B, T, _ = x.shape
+ M, D = self.embedding.size()
+ x_flat = x.detach().reshape(-1, D)
+
+ distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
+ torch.sum(x_flat ** 2, dim=1, keepdim=True),
+ x_flat, self.embedding.t(),
+ alpha=-2.0, beta=1.0) # [B*T_mel, N_vq]
+ indices = torch.argmin(distances.float(), dim=-1) # [B*T_mel]
+ quantized = F.embedding(indices, self.embedding)
+ quantized = quantized.view_as(x)
+ return x_flat, quantized, indices
+
+ def forward(self, x):
+ """
+
+ :param x: [B, T, D]
+ :return: [B, T, D]
+ """
+ B, T, _ = x.shape
+ M, D = self.embedding.size()
+
+ x_flat, quantized, indices = self.encode(x)
+ encodings = F.one_hot(indices, M).float()
+ indices = indices.reshape(B, T)
+
+ # DeepMind def does not do this but I find I have to... ;\
+ if self.training and self.data_initialized.item() == 0:
+ print('| running kmeans in VQVAE') # data driven initialization for the embeddings
+ rp = torch.randperm(x_flat.size(0))
+ kd = kmeans2(x_flat[rp].data.cpu().numpy(), self.n_embeddings, minit='points')
+ self.embedding.copy_(torch.from_numpy(kd[0]))
+ self.data_initialized.fill_(1)
+ # TODO: this won't work in multi-GPU setups
+ x_flat, quantized, indices = self.encode(x)
+ encodings = F.one_hot(indices, M).float()
+ indices = indices.reshape(B, T)
+
+ # vector quantization cost that trains the embedding vectors
+ loss = self.commitment_cost * (x.detach() - quantized).pow(2).mean() + \
+ (quantized - x.detach()).pow(2).mean()
+ loss *= self.lambda_kl
+
+ quantized = x + (quantized - x).detach()
+
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+ return quantized, loss, indices, perplexity
diff --git a/modules/commons/vqvae_cvq.py b/modules/commons/vqvae_cvq.py
new file mode 100644
index 0000000000000000000000000000000000000000..082039d3566b1b618d9bb54878122ab48de6cdbc
--- /dev/null
+++ b/modules/commons/vqvae_cvq.py
@@ -0,0 +1,190 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch import einsum
+from einops import rearrange
+import torch.distributed as dist
+
+from utils.commons.hparams import hparams
+
+
+class ClusteringVectorQuantiser(nn.Module):
+ """
+ Improved version over vector quantiser, with the dynamic initialisation
+ for these unoptimised "dead" points.
+ num_embed: number of codebook entry
+ embed_dim: dimensionality of codebook entry
+ beta: weight for the commitment loss
+ distance: distance for looking up the closest code
+ anchor: anchor sampled methods
+ first_batch: if true, the offline version of our model
+ contras_loss: if true, use the contras_loss to further improve the performance
+ """
+ def __init__(self, num_embed=1024, embed_dim=512, beta=0.25, distance='l2',
+ anchor='closest', first_batch=False, contras_loss=True):
+ super().__init__()
+
+ self.num_embed = num_embed
+ self.embed_dim = embed_dim
+ self.beta = beta
+ self.distance = distance
+ self.anchor = anchor
+ self.first_batch = first_batch
+ self.contras_loss = contras_loss
+ self.decay = 0.99
+ self.init = False
+
+ self.pool = FeaturePool(self.num_embed, self.embed_dim)
+ self.embedding = nn.Embedding(self.num_embed, self.embed_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed)
+ self.register_buffer("embed_prob", torch.zeros(self.num_embed))
+
+
+ def forward(self, z, mask=None, temp=None, rescale_logits=False, return_logits=False):
+ if mask is not None:
+ assert mask.shape[:2] == z.shape[:2], (mask.shape, z.shape)
+ assert mask.shape[-1] == 1, (mask.shape,)
+ z = z * mask
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
+ assert return_logits == False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ # z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ assert z.shape[-1] == self.embed_dim
+ z_flattened = z.view(-1, self.embed_dim)
+
+ # clculate the distance
+ if self.distance == 'l2':
+ # l2 distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = - torch.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - \
+ torch.sum(self.embedding.weight ** 2, dim=1) + \
+ 2 * torch.einsum('bd, dn-> bn', z_flattened.detach(), rearrange(self.embedding.weight, 'n d-> d n'))
+ elif self.distance == 'cos':
+ # cosine distances from z to embeddings e_j
+ normed_z_flattened = F.normalize(z_flattened, dim=1).detach()
+ normed_codebook = F.normalize(self.embedding.weight, dim=1)
+ d = torch.einsum('bd,dn->bn', normed_z_flattened, rearrange(normed_codebook, 'n d -> d n'))
+
+ # encoding
+ sort_distance, indices = d.sort(dim=1)
+ # look up the closest point for the indices
+ encoding_indices = indices[:,-1]
+ encodings = torch.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device)
+ encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
+
+ # quantise and unflatten
+ z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape)
+ # compute loss for embedding
+ loss = self.beta * (z_q.detach() - z) ** 2 + (z_q - z.detach()) ** 2
+ if mask is not None:
+ loss = (loss * mask).sum() / mask.sum() / self.embed_dim
+ else:
+ loss = loss.mean()
+ # loss = self.beta * torch.mean((z_q.detach()-z)**2) + torch.mean((z_q - z.detach()) ** 2)
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+ # reshape back to match original input shape
+ # z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+ # count
+ # import pdb
+ # pdb.set_trace()
+ avg_probs = torch.mean(encodings, dim=0)
+ # perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+ # min_encodings = encodings
+
+ # online clustered reinitialisation for unoptimized points
+ if self.training:
+ # calculate the average usage of code entries
+ self.embed_prob.mul_(self.decay).add_(avg_probs, alpha= 1 - self.decay)
+ # running average updates
+ if self.anchor in ['closest', 'random', 'probrandom'] and (not self.init):
+ # closest sampling
+ if self.anchor == 'closest':
+ sort_distance, indices = d.sort(dim=0)
+ random_feat = z_flattened.detach()[indices[-1,:]]
+ # feature pool based random sampling
+ elif self.anchor == 'random':
+ random_feat = self.pool.query(z_flattened.detach())
+ # probabilitical based random sampling
+ elif self.anchor == 'probrandom':
+ norm_distance = F.softmax(d.t(), dim=1)
+ prob = torch.multinomial(norm_distance, num_samples=1).view(-1)
+ random_feat = z_flattened.detach()[prob]
+ # decay parameter based on the average usage
+ decay = torch.exp(-(self.embed_prob*self.num_embed*10)/(1-self.decay)-1e-3).unsqueeze(1).repeat(1, self.embed_dim)
+ if hparams.get('reduce_cvq_embed') and dist.is_initialized():
+ # 确保在所有GPU上同步embedding的权重
+ dist.all_reduce(random_feat.data, op=dist.ReduceOp.SUM)
+ random_feat.data /= dist.get_world_size()
+ self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay
+ if self.first_batch:
+ self.init = True
+ # contrastive loss
+ if self.contras_loss:
+ sort_distance, indices = d.sort(dim=0)
+ dis_pos = sort_distance[-max(1, int(sort_distance.size(0)/self.num_embed)):,:].mean(dim=0, keepdim=True)
+ dis_neg = sort_distance[:int(sort_distance.size(0)*1/2),:]
+ dis = torch.cat([dis_pos, dis_neg], dim=0).t() / 0.07
+ contra_loss = F.cross_entropy(dis, torch.zeros((dis.size(0),), dtype=torch.long, device=dis.device))
+ loss += contra_loss
+
+ encoding_indices = encoding_indices.reshape(z.shape[:-1])
+ return z_q, loss, encoding_indices
+
+ def get_codebook_entry(self, encoding_indices):
+ # # get quantized latent vectors
+ # print(encoding_indices.shape)
+ # encoding_indices = encoding_indices.view(-1)
+ # encodings = torch.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=encoding_indices.device)
+ # print(encodings.shape)
+ # encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
+ # print(encodings.shape)
+ # # quantise and unflatten
+ # z_q = torch.matmul(encodings, self.embedding.weight).view(encoding_indices.shape[0], -1)
+ z_q = self.embedding(encoding_indices)
+ return z_q
+
+class FeaturePool():
+ """
+ This class implements a feature buffer that stores previously encoded features
+
+ This buffer enables us to initialize the codebook using a history of generated features
+ rather than the ones produced by the latest encoders
+ """
+ def __init__(self, pool_size, dim=64):
+ """
+ Initialize the FeaturePool class
+
+ Parameters:
+ pool_size(int) -- the size of featue buffer
+ """
+ self.pool_size = pool_size
+ if self.pool_size > 0:
+ self.nums_features = 0
+ self.features = (torch.rand((pool_size, dim)) * 2 - 1)/ pool_size
+
+ def query(self, features):
+ """
+ return features from the pool
+ """
+ self.features = self.features.to(features.device)
+ if self.nums_features < self.pool_size:
+ if features.size(0) > self.pool_size: # if the batch size is large enough, directly update the whole codebook
+ random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
+ self.features = features[random_feat_id]
+ self.nums_features = self.pool_size
+ else:
+ # if the mini-batch is not large nuough, just store it for the next update
+ num = self.nums_features + features.size(0)
+ self.features[self.nums_features:num] = features
+ self.nums_features = num
+ else:
+ if features.size(0) > int(self.pool_size):
+ random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
+ self.features = features[random_feat_id]
+ else:
+ random_id = torch.randperm(self.pool_size)
+ self.features[random_id[:features.size(0)]] = features
+
+ return self.features
\ No newline at end of file
diff --git a/modules/commons/vqvae_fsq.py b/modules/commons/vqvae_fsq.py
new file mode 100644
index 0000000000000000000000000000000000000000..12ade280e20a2f1cb9701e465e7335d45dee286a
--- /dev/null
+++ b/modules/commons/vqvae_fsq.py
@@ -0,0 +1,72 @@
+"""
+Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
+Code adapted from Jax version in Appendix A.1
+"""
+
+from typing import List
+
+import torch
+import torch.nn as nn
+from torch import Tensor, int32
+
+
+def round_ste(z: Tensor) -> Tensor:
+ """Round with straight through gradients."""
+ zhat = z.round()
+ return z + (zhat - z).detach()
+
+
+class FSQ(nn.Module):
+ def __init__(self, levels: List[int]):
+ super().__init__()
+ _levels = torch.tensor(levels, dtype=int32)
+ self.register_buffer("_levels", _levels)
+
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
+ self.register_buffer("_basis", _basis)
+
+ self.dim = len(levels)
+ self.n_codes = self._levels.prod().item()
+ implicit_codebook = self.indices_to_codes(torch.arange(self.n_codes))
+ self.register_buffer("implicit_codebook", implicit_codebook)
+
+ def forward(self, z: Tensor) -> Tensor:
+ zhat = self.quantize(z)
+ indices = self.codes_to_indices(zhat)
+ return zhat, indices
+
+ def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
+ """Bound `z`, an array of shape (..., d)."""
+ half_l = (self._levels - 1) * (1 - eps) / 2
+ offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
+ shift = (offset / half_l).tan()
+ return (z + shift).tanh() * half_l - offset
+
+ def quantize(self, z: Tensor) -> Tensor:
+ """Quantizes z, returns quantized zhat, same shape as z."""
+ quantized = round_ste(self.bound(z))
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
+ return quantized / half_width
+
+ def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
+ half_width = self._levels // 2
+ return (zhat_normalized * half_width) + half_width
+
+ def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
+ half_width = self._levels // 2
+ return (zhat - half_width) / half_width
+
+ def codes_to_indices(self, zhat: Tensor) -> Tensor:
+ """Converts a `code` to an index in the codebook."""
+ assert zhat.shape[-1] == self.dim
+ zhat = self._scale_and_shift(zhat)
+ return (zhat * self._basis).sum(dim=-1).to(int32)
+
+ def indices_to_codes(self, indices: Tensor) -> Tensor:
+ """Inverse of `codes_to_indices`."""
+ indices = indices.unsqueeze(-1)
+ codes_non_centered = (indices // self._basis) % self._levels
+ return self._scale_and_shift_inverse(codes_non_centered)
+
+ def get_codebook_entry(self, encoding_indices):
+ return self.indices_to_codes(encoding_indices)
diff --git a/modules/commons/vqvae_lfq.py b/modules/commons/vqvae_lfq.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff9b0bf4837caa7bc6944952e02d4ce8e495f0bc
--- /dev/null
+++ b/modules/commons/vqvae_lfq.py
@@ -0,0 +1,276 @@
+"""
+Lookup Free Quantization
+Proposed in https://arxiv.org/abs/2310.05737
+
+basically a 2-level FSQ (Finite Scalar Quantization) with entropy loss
+https://arxiv.org/abs/2309.15505
+"""
+
+from math import log2, ceil
+from collections import namedtuple
+
+import torch
+from torch import nn, Tensor, einsum
+import torch.nn.functional as F
+from torch.nn import Module
+
+from einops import rearrange, reduce, pack, unpack
+
+# constants
+
+# Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
+
+LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
+
+# helper functions
+
+def exists(v):
+ return v is not None
+
+def default(*args):
+ for arg in args:
+ if exists(arg):
+ return arg() if callable(arg) else arg
+ return None
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+# distance
+
+def euclidean_distance_squared(x, y):
+ x2 = reduce(x ** 2, '... n d -> ... n', 'sum')
+ y2 = reduce(y ** 2, 'n d -> n', 'sum')
+ xy = einsum('... i d, j d -> ... i j', x, y) * -2
+ return rearrange(x2, '... i -> ... i 1') + y2 + xy
+
+# entropy
+
+def log(t, eps = 1e-20):
+ return t.clamp(min = eps).log()
+
+def entropy(prob):
+ return -prob * log(prob)
+
+# class
+
+class LFQ(Module):
+ def __init__(
+ self,
+ *,
+ dim = None,
+ codebook_size = None,
+ entropy_loss_weight = 0.1,
+ commitment_loss_weight = 1.,
+ diversity_gamma = 2.5,
+ straight_through_activation = nn.Identity(),
+ num_codebooks = 1,
+ keep_num_codebooks_dim = None,
+ codebook_scale = 1. # for residual LFQ, codebook scaled down by 2x at each layer
+ ):
+ super().__init__()
+
+ # some assert validations
+
+ assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
+ assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
+
+ codebook_size = default(codebook_size, lambda: 2 ** dim)
+ codebook_dim = int(log2(codebook_size))
+
+ codebook_dims = codebook_dim * num_codebooks
+ dim = default(dim, codebook_dims)
+
+ self.project_in = nn.Linear(dim, codebook_dims) if dim != codebook_dims else nn.Identity()
+ self.project_out = nn.Linear(codebook_dims, dim) if dim != codebook_dims else nn.Identity()
+
+ self.dim = dim
+ self.codebook_dim = codebook_dim
+ self.num_codebooks = num_codebooks
+
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
+
+ # straight through activation
+
+ self.activation = straight_through_activation
+
+ # entropy aux loss related weights
+
+ self.diversity_gamma = diversity_gamma
+ self.entropy_loss_weight = entropy_loss_weight
+
+ # codebook scale
+
+ self.codebook_scale = codebook_scale
+
+ # commitment loss
+
+ self.commitment_loss_weight = commitment_loss_weight
+
+ # for no auxiliary loss, during inference
+
+ self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
+ self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
+ # codes
+
+ all_codes = torch.arange(codebook_size)
+ bits = ((all_codes[..., None].int() & self.mask) != 0).float()
+ codebook = self.bits_to_codes(bits)
+
+ self.register_buffer('codebook', codebook, persistent = False)
+
+ def bits_to_codes(self, bits):
+ return bits * self.codebook_scale * 2 - self.codebook_scale
+
+ @property
+ def dtype(self):
+ return self.codebook.dtype
+
+ def indices_to_codes(
+ self,
+ indices,
+ project_out = True
+ ):
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
+
+ if not self.keep_num_codebooks_dim:
+ indices = rearrange(indices, '... -> ... 1')
+
+ # indices to codes, which are bits of either -1 or 1
+
+ bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
+
+ codes = self.bits_to_codes(bits)
+
+ codes = rearrange(codes, '... c d -> ... (c d)')
+
+ # whether to project codes out to original dimensions
+ # if the input feature dimensions were not log2(codebook size)
+
+ if project_out:
+ codes = self.project_out(codes)
+
+ # rearrange codes back to original shape
+
+ if is_img_or_video:
+ codes = rearrange(codes, 'b ... d -> b d ...')
+
+ return codes
+
+ def forward(
+ self,
+ x,
+ mask=None,
+ inv_temperature = 1.,
+ return_loss_breakdown = False
+ ):
+ """
+ einstein notation
+ b - batch
+ n - sequence (or flattened spatial dimensions)
+ d - feature dimension, which is also log2(codebook size)
+ c - number of codebook dim
+ """
+
+ is_img_or_video = x.ndim >= 4
+
+ # standardize image or video into (batch, seq, dimension)
+
+ if is_img_or_video:
+ x = rearrange(x, 'b d ... -> b ... d')
+ x, ps = pack_one(x, 'b * d')
+
+ assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
+
+ x = self.project_in(x)
+
+ # split out number of codebooks
+
+ x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
+
+ # quantize by eq 3.
+
+ original_input = x
+
+ codebook_value = torch.ones_like(x) * self.codebook_scale
+ quantized = torch.where(x > 0, codebook_value, -codebook_value)
+
+ # use straight-through gradients with tanh (or custom activation fn) if training
+
+ if self.training:
+ x = self.activation(x)
+ x = x - x.detach() + quantized
+ else:
+ x = quantized
+
+ # calculate indices
+
+ indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
+
+ # entropy aux loss
+
+ if self.training:
+ distance = euclidean_distance_squared(original_input, self.codebook)
+
+ prob = (-distance * inv_temperature).softmax(dim = -1)
+
+ per_sample_entropy = entropy(prob).mean()
+
+ avg_prob = reduce(prob, 'b n c d -> b c d', 'mean')
+ codebook_entropy = entropy(avg_prob).mean()
+
+ # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
+ # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
+
+ entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
+ else:
+ # if not training, just return dummy 0
+ entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
+
+ # commit loss
+
+ if self.training:
+ commit_loss = F.mse_loss(original_input, quantized.detach())
+ else:
+ commit_loss = self.zero
+
+ # merge back codebook dim
+
+ x = rearrange(x, 'b n c d -> b n (c d)')
+
+ # project out to feature dimension if needed
+
+ x = self.project_out(x)
+
+ # reconstitute image or video dimensions
+
+ if is_img_or_video:
+ x = unpack_one(x, ps, 'b * d')
+ x = rearrange(x, 'b ... d -> b d ...')
+
+ indices = unpack_one(indices, ps, 'b * c')
+
+ # whether to remove single codebook dim
+
+ if not self.keep_num_codebooks_dim:
+ indices = rearrange(indices, '... 1 -> ...')
+
+ # complete aux loss
+
+ aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
+
+ ret = x, aux_loss, indices
+
+ if not return_loss_breakdown:
+ return ret
+
+ return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
+
+ def get_codebook_entry(self, encoding_indices):
+ return self.indices_to_codes(encoding_indices)
diff --git a/modules/commons/vqvae_lfq_y.py b/modules/commons/vqvae_lfq_y.py
new file mode 100644
index 0000000000000000000000000000000000000000..b34ead5d2481801a6a966b7d560b326e8083e310
--- /dev/null
+++ b/modules/commons/vqvae_lfq_y.py
@@ -0,0 +1,109 @@
+"""
+Lookup Free Quantization
+Proposed in https://arxiv.org/abs/2310.05737
+
+basically a 2-level FSQ (Finite Scalar Quantization) with entropy loss
+https://arxiv.org/abs/2309.15505
+"""
+
+import torch
+from einops import rearrange
+from torch.nn import Module
+
+
+# entropy
+
+def binary_entropy(prob):
+ return -prob * log(prob) - (1 - prob) * log(1 - prob)
+
+
+# tensor helpers
+
+def log(t, eps=1e-20):
+ return t.clamp(min=eps).log()
+
+
+# convert to bit representations and back
+
+def decimal_to_bits(x: torch.LongTensor, bits: int) -> torch.FloatTensor:
+ # [b, ...] {0, 1, ..., max - 1} -> [b, ..., d] {-1, 1}
+ mask = 2 ** torch.arange(bits).to(x) # [d]
+ bits = ((x.unsqueeze(-1) & mask) != 0).float() # [b, n, d] {0, 1}
+ return bits * 2 - 1 # {0, 1} -> {-1, 1}
+
+
+def bits_to_decimal(x: torch.FloatTensor) -> torch.LongTensor:
+ # [b, ..., d] {-1, 1} -> [b, ...] {0, 1, ..., max - 1}
+ x = (x > 0).long() # {-1, 1} -> {0, 1}, [b, ..., d]
+ mask = 2 ** torch.arange(x.size(-1)).to(x) # [d]
+ dec = (x * mask).sum(-1) # [b, ...]
+ return dec
+
+
+# class
+
+class LFQY(Module):
+ def __init__(self, dim, entropy_loss_weight=0.1, diversity_gamma=1.0):
+ super().__init__()
+ self.dim = dim
+ self.diversity_gamma = diversity_gamma
+ self.entropy_loss_weight = entropy_loss_weight
+
+ def indices_to_codes(self, indices):
+ codes = decimal_to_bits(indices, self.dim)
+ # codes = rearrange(codes, 'b ... d -> b d ...')
+ return codes
+
+ def forward(self, x, mask=None, inv_temperature=1.):
+ """
+ einstein notation
+ b - batch
+ n - sequence (or flattened spatial dimensions)
+ d - feature dimension, which is also log2(codebook size)
+ """
+ # x = rearrange(x, 'b d ... -> b ... d')
+
+ assert x.shape[-1] == self.dim
+ z = torch.tanh(x / inv_temperature) # (-1, 1)
+
+ # quantize by eq 3.
+ quantized = torch.sign(x) # {-1, 1}
+ z = z + (quantized - z).detach()
+
+ # calculate indices
+ indices = bits_to_decimal(z)
+
+ # entropy aux loss
+ if self.training:
+ prob = torch.sigmoid(x / inv_temperature) # [b, ..., d]
+
+ bit_entropy = binary_entropy(prob).sum(-1).mean()
+ # E[H(q)] = avg(sum(H(q_i)))
+
+ avg_prob = prob.flatten(0, -2).mean(0) # [b, ..., d] -> [n, d] -> [d]
+ codebook_entropy = binary_entropy(avg_prob).sum()
+ # H(E[q]) = sum(H(avg(q_i)))
+
+ """
+ 1. entropy will be nudged to be low for each bit,
+ so each scalar commits to one latent binary bit or the other.
+ 2. codebook entropy will be nudged to be high,
+ to encourage all codes to be uniformly used.
+ """
+
+ entropy_aux_loss = bit_entropy - self.diversity_gamma * codebook_entropy
+ else:
+ # if not training, just return dummy 0
+ entropy_aux_loss = torch.zeros(1).to(z)
+
+ entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight
+
+ # reconstitute image or video dimensions
+
+ # z = rearrange(z, 'b ... d -> b d ...')
+
+ # bits to decimal for the codebook indices
+ return z, entropy_aux_loss, indices
+
+ def get_codebook_entry(self, encoding_indices):
+ return self.indices_to_codes(encoding_indices)
diff --git a/modules/commons/vqvae_taming.py b/modules/commons/vqvae_taming.py
new file mode 100644
index 0000000000000000000000000000000000000000..59b7abff0050186aacdd5899f142c5dcbcf49295
--- /dev/null
+++ b/modules/commons/vqvae_taming.py
@@ -0,0 +1,428 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from scipy.cluster.vq import kmeans2
+from torch import einsum
+from einops import rearrange
+import torch.distributed as dist
+
+
+class VectorQuantizer(nn.Module):
+ """
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+ ____________________________________________
+ Discretization bottleneck part of the VQ-VAE.
+ Inputs:
+ - n_e : number of embeddings
+ - e_dim : dimension of embedding
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ _____________________________________________
+ """
+
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
+ # used wherever VectorQuantizer has been used before and is additionally
+ # more efficient.
+ def __init__(self, n_e, e_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ def forward(self, z):
+ """
+ Inputs the output of the encoder network z and maps it to a discrete
+ one-hot vector that is the index of the closest embedding vector e_j
+ z (continuous) -> z_q (discrete)
+ z.shape = (batch, channel, height, width)
+ quantization pipeline:
+ 1. get encoder input (B,C,H,W)
+ 2. flatten input to (B*H*W,C)
+ """
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.matmul(z_flattened, self.embedding.weight.t())
+
+ ## could possible replace this here
+ # #\start...
+ # find closest encodings
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+
+ min_encodings = torch.zeros(
+ min_encoding_indices.shape[0], self.n_e).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # dtype min encodings: torch.float32
+ # min_encodings shape: torch.Size([2048, 512])
+ # min_encoding_indices.shape: torch.Size([2048, 1])
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ # .........\end
+
+ # with:
+ # .........\start
+ # min_encoding_indices = torch.argmin(d, dim=1)
+ # z_q = self.embedding(min_encoding_indices)
+ # ......\end......... (TODO)
+
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ # TODO: check for more easy handling with nn.Embedding
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
+ min_encodings.scatter_(1, indices[:, None], 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantize(nn.Module):
+ """
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
+ remap=None, unknown_index="random"):
+ super().__init__()
+
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.use_vqinterface = use_vqinterface
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, return_logits=False):
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:, self.used, ...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:, self.used, ...] = soft_one_hot
+ soft_one_hot = full_zeros
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+
+ ind = soft_one_hot.argmax(dim=1)
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+ if self.use_vqinterface:
+ if return_logits:
+ return z_q, diff, (None, None, ind), logits
+ return z_q, diff, (None, None, ind)
+ return z_q, diff, ind
+
+ def get_codebook_entry(self, indices, shape):
+ b, h, w, c = shape
+ assert b * h * w == indices.shape[0]
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
+ return z_q
+
+
+class VectorQuantizer2(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, legacy=False):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.re_embed = n_e
+
+ def encode(self, z):
+ B, T, _ = z.shape
+ z_flattened = z.reshape(-1, self.e_dim)
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).reshape(z.shape)
+
+ z_q = z_q.view_as(z)
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[:-1])
+ return z_flattened, z_q, min_encoding_indices
+
+ def forward(self, z, mask=None, temp=None, rescale_logits=False, return_logits=False):
+ if mask is not None:
+ assert mask.shape[:2] == z.shape[:2], (mask.shape, z.shape)
+ assert mask.shape[-1] == 1, (mask.shape,)
+ z = z * mask
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
+ assert return_logits == False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ # z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ assert z.shape[-1] == self.e_dim
+ z_flattened = z.reshape(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.matmul(z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+ #torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).reshape(z.shape)
+ perplexity = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * (z_q.detach() - z) ** 2 + \
+ (z_q - z.detach()) ** 2
+ else:
+ loss = (z_q.detach() - z) ** 2 + self.beta * \
+ (z_q - z.detach()) ** 2
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[:-1])
+ if mask is not None:
+ loss = (loss * mask).sum() / mask.sum() / self.e_dim
+ else:
+ loss = loss.mean()
+ return z_q, loss, min_encoding_indices, perplexity
+
+ def get_codebook_entry(self, indices, shape=None):
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class VectorQuantizer4(nn.Module):
+ def __init__(self, n_e, e_dim, beta, legacy=False, kmeans_reset_every=1000):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.re_embed = n_e
+ self.reset_every = kmeans_reset_every
+ self.reset_thres = 20
+ self.z_buffer = []
+ self.register_buffer('use_flag', torch.zeros(n_e))
+ self.register_buffer('steps', torch.zeros(1))
+
+ def encode(self, z):
+ B, T, _ = z.shape
+ z_flattened = z.reshape(-1, self.e_dim)
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).reshape(z.shape)
+
+ z_q = z_q.view_as(z)
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[:-1])
+ return z_flattened, z_q, min_encoding_indices
+
+ def forward(self, z, mask=None, temp=None, rescale_logits=False, return_logits=False):
+ if mask is not None:
+ assert mask.shape[:2] == z.shape[:2], (mask.shape, z.shape)
+ assert mask.shape[-1] == 1, (mask.shape,)
+ z = z * mask
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
+ assert return_logits == False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ # z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ assert z.shape[-1] == self.e_dim
+ z_flattened = z.reshape(-1, self.e_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).reshape(z.shape)
+ perplexity = None
+
+ if self.training:
+ self.steps += 1
+ self.use_flag += torch.bincount(min_encoding_indices, minlength=self.n_e)
+ is_master = not dist.is_initialized() or dist.get_rank() == 0
+ if self.reset_every - 100 <= self.steps <= self.reset_every:
+ if dist.is_initialized():
+ z_buffer_ = [None for _ in range(dist.get_world_size())]
+ dist.all_gather_object(z_buffer_, z_flattened.detach().cpu())
+ else:
+ z_buffer_ = [z_flattened.detach().cpu()]
+ self.z_buffer += z_buffer_
+
+ if self.steps % self.reset_every == 0:
+ if dist.is_initialized():
+ dist.all_reduce(self.use_flag)
+ vq_usage = (self.use_flag > self.reset_thres).sum().item() / self.use_flag.shape[0]
+ print("| VQ usage: ", vq_usage)
+ if vq_usage != 1:
+ if is_master:
+ if self.steps.item() == self.reset_every:
+ print('| running kmeans in VQVAE') # data driven initialization for the embeddings
+ z_buffer = torch.cat(self.z_buffer, 0)
+ rp = torch.randperm(z_buffer.shape[0])
+ kd = kmeans2(z_buffer[rp].numpy(), self.n_e, minit='points')[0]
+ self.embedding.weight.data = torch.from_numpy(kd).to(z.device)
+ else:
+ reset_ids = self.use_flag < self.reset_thres
+ keep_ids = self.use_flag >= self.reset_thres
+ t = torch.randint(0, keep_ids.sum(), [reset_ids.sum()], device=self.use_flag.device)
+ keep_ids = torch.where(keep_ids)[0][t]
+ self.embedding.weight.data[reset_ids] = self.embedding.weight.data[keep_ids].clone()
+ if dist.is_initialized():
+ dist.broadcast(self.embedding.weight.data, 0)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).reshape(z.shape)
+ self.use_flag.fill_(0)
+ self.z_buffer = []
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * (z_q.detach() - z) ** 2 + \
+ (z_q - z.detach()) ** 2
+ else:
+ loss = (z_q.detach() - z) ** 2 + self.beta * \
+ (z_q - z.detach()) ** 2
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[:-1])
+ if mask is not None:
+ loss = (loss * mask).sum() / mask.sum() / self.e_dim
+ else:
+ loss = loss.mean()
+ return z_q, loss, min_encoding_indices, perplexity
+
+ def get_codebook_entry(self, indices, shape=None):
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
diff --git a/modules/commons/wavenet.py b/modules/commons/wavenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7809c9b9d3331ba4fd2ffd4caae14e721e4b0732
--- /dev/null
+++ b/modules/commons/wavenet.py
@@ -0,0 +1,97 @@
+import torch
+from torch import nn
+
+
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+class WN(torch.nn.Module):
+ def __init__(self, hidden_size, kernel_size, dilation_rate, n_layers, c_cond=0,
+ p_dropout=0, share_cond_layers=False, is_BTC=False):
+ super(WN, self).__init__()
+ assert (kernel_size % 2 == 1)
+ assert (hidden_size % 2 == 0)
+ self.is_BTC = is_BTC
+ self.hidden_size = hidden_size
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = c_cond
+ self.p_dropout = p_dropout
+ self.share_cond_layers = share_cond_layers
+
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+ self.drop = nn.Dropout(p_dropout)
+
+ if c_cond != 0 and not share_cond_layers:
+ cond_layer = torch.nn.Conv1d(c_cond, 2 * hidden_size * n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+
+ for i in range(n_layers):
+ dilation = dilation_rate ** i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(hidden_size, 2 * hidden_size, kernel_size,
+ dilation=dilation, padding=padding)
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+ self.in_layers.append(in_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * hidden_size
+ else:
+ res_skip_channels = hidden_size
+
+ res_skip_layer = torch.nn.Conv1d(hidden_size, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, x, nonpadding=None, cond=None):
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ cond = cond.transpose(1, 2) if cond is not None else None
+ nonpadding = nonpadding.transpose(1, 2) if nonpadding is not None else None
+ if nonpadding is None:
+ nonpadding = 1
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_size])
+
+ if cond is not None and not self.share_cond_layers:
+ cond = self.cond_layer(cond)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ x_in = self.drop(x_in)
+ if cond is not None:
+ cond_offset = i * 2 * self.hidden_size
+ cond_l = cond[:, cond_offset:cond_offset + 2 * self.hidden_size, :]
+ else:
+ cond_l = torch.zeros_like(x_in)
+
+ acts = fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ x = (x + res_skip_acts[:, :self.hidden_size, :]) * nonpadding
+ output = output + res_skip_acts[:, self.hidden_size:, :]
+ else:
+ output = output + res_skip_acts
+ output = output * nonpadding
+ if self.is_BTC:
+ output = output.transpose(1, 2)
+ return output
+
+ def remove_weight_norm(self):
+ def remove_weight_norm(m):
+ try:
+ nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(remove_weight_norm)
diff --git a/modules/eg3ds/camera_utils/pose_sampler.py b/modules/eg3ds/camera_utils/pose_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e36f3bcac364ab993ad59f25ce7f90726f32ceb
--- /dev/null
+++ b/modules/eg3ds/camera_utils/pose_sampler.py
@@ -0,0 +1,216 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""
+Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts.
+"""
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+
+from modules.eg3ds.volumetric_rendering import math_utils
+
+
+class UnifiedCameraPoseSampler():
+ """
+ A unified class for obtain camera pose, a 25 dimension vector that consists of camera2world matrix (4x4) and camera intrinsic (3,3)
+ it utilize the samplers constructed below.
+ """
+ def get_camera_pose(self, pitch, yaw, lookat_location=None, distance_to_orig=2.7, batch_size=1, device='cpu', roll=None):
+ if lookat_location is None:
+ lookat_location = torch.tensor([0., 0., -0.2], device=device)
+
+ c2w = LookAtPoseSampler.sample(yaw, pitch, lookat_location, 0, 0, distance_to_orig, batch_size, device, roll=roll).reshape([batch_size, 16])
+ intrinsics = torch.tensor([[4.2647, 0, 0.5], [0, 4.2647, 0.5], [0, 0, 1]], device=device).reshape([9,]).unsqueeze(0).repeat([batch_size, 1])
+ # intrinsics = FOV_to_intrinsics(fov_degrees, device=device).reshape([9,]).unsqueeze(0).repeat([batch_size, 1])
+ camera = torch.cat([c2w, intrinsics], dim=1) # [batch, 25]
+ return camera
+
+
+class GaussianCameraPoseSampler:
+ """
+ Samples pitch and yaw from a Gaussian distribution and returns a camera pose.
+ Camera is specified as looking at the origin.
+ If horizontal and vertical stddev (specified in radians) are zero, gives a
+ deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean.
+ The coordinate system is specified with y-up, z-forward, x-left.
+ Horizontal mean is the azimuthal angle (rotation around y axis) in radians,
+ vertical mean is the polar angle (angle from the y axis) in radians.
+ A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2.
+
+ Example:
+ For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+ cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+ """
+ horizontal_mean: 偏转角, 也叫方位角, -pi/2 denotes camera at left, 0 denotes forward, pi/2 denotes right,
+ vertical_mean: 俯仰角, 0 denotes up, -pi/2 denotes camera at up, 0 means horizontal, pi/2 denotes down. however, 0.2 is a good choice for front face.
+ """
+ assert horizontal_mean < np.pi/2 + 1e-5 and horizontal_mean > - np.pi/2 - 1e-5
+ assert vertical_mean < np.pi/2 + 1e-5 and vertical_mean > - np.pi/2 - 1e-5
+ horizontal_mean += np.pi/2
+ vertical_mean += np.pi/2
+ h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
+ v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+ theta = h
+ v = v / math.pi
+ phi = torch.arccos(1 - 2*v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device)
+
+ camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+ camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+ camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+ forward_vectors = math_utils.normalize_vecs(-camera_origins) # the direction the camera is pointing, pointing to origin in this func
+ return create_cam2world_matrix(forward_vectors, camera_origins)
+
+
+class LookAtPoseSampler:
+ """
+ Same as GaussianCameraPoseSampler, except the
+ camera is specified as looking at 'lookat_position', a 3-vector.
+
+ Example:
+ For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+ cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean, vertical_mean, lookat_position, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu', roll=None):
+ """
+ horizontal_mean: 偏转角, 也叫方位角, -pi/2 denotes camera at left, 0 denotes forward, pi/2 denotes right,
+ vertical_mean: 俯仰角, 0 denotes up, -pi/2 denotes camera at up, 0 means horizontal, pi/2 denotes down. however, 0.2 is a good choice for front face.
+ """
+ # assert horizontal_mean < np.pi + 1e-5 and horizontal_mean > - np.pi - 1e-5
+ # assert vertical_mean < np.pi + 1e-5 and vertical_mean > - np.pi - 1e-5
+ horizontal_mean += np.pi/2
+ vertical_mean += np.pi/2
+
+ # if horizontal_mean < -np.pi:
+ # horizontal_mean += 2*np.pi
+ # if vertical_mean < -np.pi:
+ # vertical_mean += 2*np.pi
+ # if horizontal_mean > np.pi:
+ # horizontal_mean -= 2*np.pi
+ # if vertical_mean > np.pi:
+ # vertical_mean -= 2*np.pi
+
+ h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
+ v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+ theta = h # 球坐标系里的滚转角
+ v = v / math.pi
+ phi = torch.arccos(1 - 2*v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device)
+
+ # radius*torch.sin(phi) 是球半径在水平平面上的投影,随后再根据yaw角来分别计算x和y
+ # radius*torch.cos(phi)则是纵轴的分量
+ camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+ camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+ camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+ # forward_vectors = math_utils.normalize_vecs(-camera_origins)
+ forward_vectors = math_utils.normalize_vecs(lookat_position.to(device) - camera_origins) # the direction the camera is pointing, pointing to the lookat_position
+ return create_cam2world_matrix(forward_vectors, camera_origins, roll)
+
+
+class UniformCameraPoseSampler:
+ """
+ Same as GaussianCameraPoseSampler, except the
+ pose is sampled from a UNIFORM distribution with range +-[horizontal/vertical]_stddev, instead of a GAUSSIAN distribution.
+
+ Example:
+ For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians:
+
+ cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+ """
+ horizontal_mean: 偏转角, 也叫方位角, -pi/2 denotes camera at left, 0 denotes forward, pi/2 denotes right,
+ vertical_mean: 俯仰角, 0 denotes up, -pi/2 denotes camera at up, 0 means horizontal, pi/2 denotes down. however, 0.2 is a good choice for front face.
+ """
+ assert horizontal_mean < np.pi/2 + 1e-5 and horizontal_mean > - np.pi/2 - 1e-5
+ assert vertical_mean < np.pi/2 + 1e-5 and vertical_mean > - np.pi/2 - 1e-5
+ horizontal_mean += np.pi/2
+ vertical_mean += np.pi/2
+
+ h = (torch.rand((batch_size, 1), device=device) * 2 - 1) * horizontal_stddev + horizontal_mean
+ v = (torch.rand((batch_size, 1), device=device) * 2 - 1) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+ theta = h
+ v = v / math.pi
+ phi = torch.arccos(1 - 2*v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device) # the location of camera
+
+ camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+ camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+ camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+ forward_vectors = math_utils.normalize_vecs(-camera_origins) # the direction the camera is pointing, pointing to origin in this func
+ return create_cam2world_matrix(forward_vectors, camera_origins)
+
+
+def create_cam2world_matrix(forward_vector, origin, roll=None):
+ """
+ Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.
+ Works on batches of forward_vectors, origins. Assumes y-axis is up.
+ Modified by yerfor to support roll controll
+ roll: Default None, leads to 0 roll; or Tensor([Batch_size, 1]), with radian in [-pi, pi]
+ """
+
+ batch_size = len(forward_vector)
+ forward_vector = math_utils.normalize_vecs(forward_vector)
+ # up_vector 代表相机的正上方方向向量,所以可以通过旋转它来控制roll
+ up_vector = torch.zeros([batch_size, 3], dtype=forward_vector.dtype, device=forward_vector.device)
+ if roll is None:
+ roll = torch.zeros([batch_size, 1], dtype=forward_vector.dtype, device=forward_vector.device)
+ else:
+ roll = roll.reshape([batch_size, 1])
+
+ up_vector[:, 0] = torch.sin(roll)
+ up_vector[:, 1] = torch.cos(roll)
+
+ right_vector = -math_utils.normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
+ up_vector = math_utils.normalize_vecs(torch.cross(forward_vector, right_vector, dim=-1))
+
+ rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
+ rotation_matrix[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), axis=-1)
+
+ translation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
+ translation_matrix[:, :3, 3] = origin
+ cam2world = (translation_matrix @ rotation_matrix)[:, :, :]
+ assert(cam2world.shape[1:] == (4, 4))
+ return cam2world
+
+
+def FOV_to_intrinsics(fov_degrees=18.837, device='cpu'):
+ """
+ Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
+ Note the intrinsics are returned as normalized by image size, rather than in pixel units.
+ Assumes principal point is at image center.
+ """
+
+ focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))
+ intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+ return intrinsics
\ No newline at end of file
diff --git a/modules/eg3ds/dnnlib/__init__.py b/modules/eg3ds/dnnlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd91ed142e955581e83948455fb71cd837215f61
--- /dev/null
+++ b/modules/eg3ds/dnnlib/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
diff --git a/modules/eg3ds/dnnlib/util.py b/modules/eg3ds/dnnlib/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b67c4e312cd1b847ca21fd3b929802a57e6f6d
--- /dev/null
+++ b/modules/eg3ds/dnnlib/util.py
@@ -0,0 +1,493 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ self[name] = value
+
+ def __delattr__(self, name: str) -> None:
+ del self[name]
+
+
+class Logger(object):
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+ self.file = None
+
+ if file_name is not None:
+ self.file = open(file_name, file_mode)
+
+ self.should_flush = should_flush
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ sys.stdout = self
+ sys.stderr = self
+
+ def __enter__(self) -> "Logger":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self.close()
+
+ def write(self, text: Union[str, bytes]) -> None:
+ """Write text to stdout (and a file) and optionally flush."""
+ if isinstance(text, bytes):
+ text = text.decode()
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+ return
+
+ if self.file is not None:
+ self.file.write(text)
+
+ self.stdout.write(text)
+
+ if self.should_flush:
+ self.flush()
+
+ def flush(self) -> None:
+ """Flush written text to both stdout and a file, if open."""
+ if self.file is not None:
+ self.file.flush()
+
+ self.stdout.flush()
+
+ def close(self) -> None:
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
+ self.flush()
+
+ # if using multiple loggers, prevent closing in wrong order
+ if sys.stdout is self:
+ sys.stdout = self.stdout
+ if sys.stderr is self:
+ sys.stderr = self.stderr
+
+ if self.file is not None:
+ self.file.close()
+ self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+ global _dnnlib_cache_dir
+ _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+ if _dnnlib_cache_dir is not None:
+ return os.path.join(_dnnlib_cache_dir, *paths)
+ if 'DNNLIB_CACHE_DIR' in os.environ:
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+ if 'HOME' in os.environ:
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+ if 'USERPROFILE' in os.environ:
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def format_time_brief(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
+ else:
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
+
+
+def ask_yes_no(question: str) -> bool:
+ """Ask the user the question until the user inputs a valid answer."""
+ while True:
+ try:
+ print("{0} [y/n]".format(question))
+ return strtobool(input().lower())
+ except ValueError:
+ pass
+
+
+def tuple_product(t: Tuple) -> Any:
+ """Calculate the product of the tuple elements."""
+ result = 1
+
+ for v in t:
+ result *= v
+
+ return result
+
+
+_str_to_ctype = {
+ "uint8": ctypes.c_ubyte,
+ "uint16": ctypes.c_uint16,
+ "uint32": ctypes.c_uint32,
+ "uint64": ctypes.c_uint64,
+ "int8": ctypes.c_byte,
+ "int16": ctypes.c_int16,
+ "int32": ctypes.c_int32,
+ "int64": ctypes.c_int64,
+ "float32": ctypes.c_float,
+ "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+ type_str = None
+
+ if isinstance(type_obj, str):
+ type_str = type_obj
+ elif hasattr(type_obj, "__name__"):
+ type_str = type_obj.__name__
+ elif hasattr(type_obj, "name"):
+ type_str = type_obj.name
+ else:
+ raise RuntimeError("Cannot infer type name from input")
+
+ assert type_str in _str_to_ctype.keys()
+
+ my_dtype = np.dtype(type_str)
+ my_ctype = _str_to_ctype[type_str]
+
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+ return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+ try:
+ with io.BytesIO() as stream:
+ pickle.dump(obj, stream)
+ return True
+ except:
+ return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+ """Searches for the underlying module behind the name to some python object.
+ Returns the module and the object name (original name with module part removed)."""
+
+ # allow convenience shorthands, substitute them by full names
+ obj_name = re.sub("^np.", "numpy.", obj_name)
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+ # list alternatives for (module_name, local_obj_name)
+ parts = obj_name.split(".")
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+ # try each alternative in turn
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ return module, local_obj_name
+ except:
+ pass
+
+ # maybe some of the modules themselves contain errors?
+ for module_name, _local_obj_name in name_pairs:
+ try:
+ importlib.import_module(module_name) # may raise ImportError
+ except ImportError:
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+ raise
+
+ # maybe the requested attribute is missing?
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ except ImportError:
+ pass
+
+ # we are out of luck, but we have no idea why
+ raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+ """Traverses the object name and returns the last (rightmost) python object."""
+ if obj_name == '':
+ return module
+ obj = module
+ for part in obj_name.split("."):
+ obj = getattr(obj, part)
+ return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+ """Finds the python object with the given name."""
+ module, obj_name = get_module_from_obj_name(name)
+ return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+ """Finds the python object with the given name and calls it as a function."""
+ assert func_name is not None
+ func_obj = get_obj_by_name(func_name)
+ assert callable(func_obj)
+ return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+ """Finds the python class with the given name and constructs it with the given arguments."""
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+ """Get the directory path of the module containing the given object name."""
+ module, _ = get_module_from_obj_name(obj_name)
+ return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+ """Return the fully-qualified name of a top-level function."""
+ assert is_top_level_function(obj)
+ module = obj.__module__
+ if module == '__main__':
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+ return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+ """List all files recursively in a given directory while ignoring given file and directory names.
+ Returns list of tuples containing both absolute and relative paths."""
+ assert os.path.isdir(dir_path)
+ base_name = os.path.basename(os.path.normpath(dir_path))
+
+ if ignores is None:
+ ignores = []
+
+ result = []
+
+ for root, dirs, files in os.walk(dir_path, topdown=True):
+ for ignore_ in ignores:
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+ # dirs need to be edited in-place
+ for d in dirs_to_remove:
+ dirs.remove(d)
+
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+ absolute_paths = [os.path.join(root, f) for f in files]
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+ if add_base_to_relative:
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+ assert len(absolute_paths) == len(relative_paths)
+ result += zip(absolute_paths, relative_paths)
+
+ return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+ """Takes in a list of tuples of (src, dst) paths and copies files.
+ Will create all necessary directories."""
+ for file in files:
+ target_dir_name = os.path.dirname(file[1])
+
+ # will create all intermediate-level directories
+ if not os.path.exists(target_dir_name):
+ os.makedirs(target_dir_name)
+
+ shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+ """Determine whether the given object is a valid URL string."""
+ if not isinstance(obj, str) or not "://" in obj:
+ return False
+ if allow_file_urls and obj.startswith('file://'):
+ return True
+ try:
+ res = requests.compat.urlparse(obj)
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ except:
+ return False
+ return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+ assert not (return_filename and (not cache))
+
+ # Doesn't look like an URL scheme so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+ # Some internet resources suggest using urllib.request.url2pathname() but
+ # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ assert is_url(url)
+
+ # Lookup from cache.
+ if cache_dir is None:
+ cache_dir = make_cache_dir_path('downloads')
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+ if cache:
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+ if len(cache_files) == 1:
+ filename = cache_files[0]
+ return filename if return_filename else open(filename, "rb")
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError("Google Drive download quota exceeded -- please try again later")
+
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Save to cache.
+ if cache:
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(temp_file, "wb") as f:
+ f.write(url_data)
+ os.replace(temp_file, cache_file) # atomic
+ if return_filename:
+ return cache_file
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
diff --git a/modules/eg3ds/metrics/__init__.py b/modules/eg3ds/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/modules/eg3ds/metrics/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/modules/eg3ds/metrics/equivariance.py b/modules/eg3ds/metrics/equivariance.py
new file mode 100644
index 0000000000000000000000000000000000000000..d105cb93031d5a9638d7a9c12c65db1d8c4a0860
--- /dev/null
+++ b/modules/eg3ds/metrics/equivariance.py
@@ -0,0 +1,270 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
+"Alias-Free Generative Adversarial Networks"."""
+
+import copy
+import numpy as np
+import torch
+import torch.fft
+from modules.eg3ds.torch_utils.ops import upfirdn2d
+
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+# Utilities.
+
+def sinc(x):
+ y = (x * np.pi).abs()
+ z = torch.sin(y) / y.clamp(1e-30, float('inf'))
+ return torch.where(y < 1e-30, torch.ones_like(x), z)
+
+def lanczos_window(x, a):
+ x = x.abs() / a
+ return torch.where(x < 1, sinc(x), torch.zeros_like(x))
+
+def rotation_matrix(angle):
+ angle = torch.as_tensor(angle).to(torch.float32)
+ mat = torch.eye(3, device=angle.device)
+ mat[0, 0] = angle.cos()
+ mat[0, 1] = angle.sin()
+ mat[1, 0] = -angle.sin()
+ mat[1, 1] = angle.cos()
+ return mat
+
+#----------------------------------------------------------------------------
+# Apply integer translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.1.
+
+def apply_integer_translation(x, tx, ty):
+ _N, _C, H, W = x.shape
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
+ ix = tx.round().to(torch.int64)
+ iy = ty.round().to(torch.int64)
+
+ z = torch.zeros_like(x)
+ m = torch.zeros_like(x)
+ if abs(ix) < W and abs(iy) < H:
+ y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
+ z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y
+ m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
+ return z, m
+
+#----------------------------------------------------------------------------
+# Apply integer translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.2.
+
+def apply_fractional_translation(x, tx, ty, a=3):
+ _N, _C, H, W = x.shape
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
+ ix = tx.floor().to(torch.int64)
+ iy = ty.floor().to(torch.int64)
+ fx = tx - ix
+ fy = ty - iy
+ b = a - 1
+
+ z = torch.zeros_like(x)
+ zx0 = max(ix - b, 0)
+ zy0 = max(iy - b, 0)
+ zx1 = min(ix + a, 0) + W
+ zy1 = min(iy + a, 0) + H
+ if zx0 < zx1 and zy0 < zy1:
+ taps = torch.arange(a * 2, device=x.device) - b
+ filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
+ filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
+ y = x
+ y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
+ y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
+ y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
+ z[:, :, zy0:zy1, zx0:zx1] = y
+
+ m = torch.zeros_like(x)
+ mx0 = max(ix + a, 0)
+ my0 = max(iy + a, 0)
+ mx1 = min(ix - b, 0) + W
+ my1 = min(iy - b, 0) + H
+ if mx0 < mx1 and my0 < my1:
+ m[:, :, my0:my1, mx0:mx1] = 1
+ return z, m
+
+#----------------------------------------------------------------------------
+# Construct an oriented low-pass filter that applies the appropriate
+# bandlimit with respect to the input and output of the given affine 2D
+# image transformation.
+
+def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
+ assert a <= amax < aflt
+ mat = torch.as_tensor(mat).to(torch.float32)
+
+ # Construct 2D filter taps in input & output coordinate spaces.
+ taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
+ yi, xi = torch.meshgrid(taps, taps)
+ xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
+
+ # Convolution of two oriented 2D sinc filters.
+ fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
+ fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
+ f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
+
+ # Convolution of two oriented 2D Lanczos windows.
+ wi = lanczos_window(xi, a) * lanczos_window(yi, a)
+ wo = lanczos_window(xo, a) * lanczos_window(yo, a)
+ w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
+
+ # Construct windowed FIR filter.
+ f = f * w
+
+ # Finalize.
+ c = (aflt - amax) * up
+ f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
+ f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
+ f = f / f.sum([0,2], keepdim=True) / (up ** 2)
+ f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
+ return f
+
+#----------------------------------------------------------------------------
+# Apply the given affine transformation to a batch of 2D images.
+
+def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
+ _N, _C, H, W = x.shape
+ mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
+
+ # Construct filter.
+ f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
+ assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
+ p = f.shape[0] // 2
+
+ # Construct sampling grid.
+ theta = mat.inverse()
+ theta[:2, 2] *= 2
+ theta[0, 2] += 1 / up / W
+ theta[1, 2] += 1 / up / H
+ theta[0, :] *= W / (W + p / up * 2)
+ theta[1, :] *= H / (H + p / up * 2)
+ theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
+ g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
+
+ # Resample image.
+ y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
+ z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+ # Form mask.
+ m = torch.zeros_like(y)
+ c = p * 2 + 1
+ m[:, :, c:-c, c:-c] = 1
+ m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
+ return z, m
+
+#----------------------------------------------------------------------------
+# Apply fractional rotation to a batch of 2D images. Corresponds to the
+# operator R_\alpha in Appendix E.3.
+
+def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+ mat = rotation_matrix(angle)
+ return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)
+
+#----------------------------------------------------------------------------
+# Modify the frequency content of a batch of 2D images as if they had undergo
+# fractional rotation -- but without actually rotating them. Corresponds to
+# the operator R^*_\alpha in Appendix E.3.
+
+def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+ mat = rotation_matrix(-angle)
+ f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
+ y = upfirdn2d.filter2d(x=x, f=f)
+ m = torch.zeros_like(y)
+ c = f.shape[0] // 2
+ m[:, :, c:-c, c:-c] = 1
+ return y, m
+
+#----------------------------------------------------------------------------
+# Compute the selected equivariance metrics for the given generator.
+
+def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
+ assert compute_eqt_int or compute_eqt_frac or compute_eqr
+
+ # Setup generator and labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ I = torch.eye(3, device=opts.device)
+ M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
+ if M is None:
+ raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+ # Sampling loop.
+ sums = None
+ progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+ progress.update(batch_start)
+ s = []
+
+ # Randomize noise buffers, if any.
+ for name, buf in G.named_buffers():
+ if name.endswith('.noise_const'):
+ buf.copy_(torch.randn_like(buf))
+
+ # Run mapping network.
+ z = torch.randn([batch_size, G.z_dim], device=opts.device)
+ c = next(c_iter)
+ ws = G.mapping(z=z, c=c)
+
+ # Generate reference image.
+ M[:] = I
+ orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+
+ # Integer translation (EQ-T).
+ if compute_eqt_int:
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+ t = (t * G.img_resolution).round() / G.img_resolution
+ M[:] = I
+ M[:2, 2] = -t
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, mask = apply_integer_translation(orig, t[0], t[1])
+ s += [(ref - img).square() * mask, mask]
+
+ # Fractional translation (EQ-T_frac).
+ if compute_eqt_frac:
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+ M[:] = I
+ M[:2, 2] = -t
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, mask = apply_fractional_translation(orig, t[0], t[1])
+ s += [(ref - img).square() * mask, mask]
+
+ # Rotation (EQ-R).
+ if compute_eqr:
+ angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
+ M[:] = rotation_matrix(-angle)
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, ref_mask = apply_fractional_rotation(orig, angle)
+ pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
+ mask = ref_mask * pseudo_mask
+ s += [(ref - pseudo).square() * mask, mask]
+
+ # Accumulate results.
+ s = torch.stack([x.to(torch.float64).sum() for x in s])
+ sums = sums + s if sums is not None else s
+ progress.update(num_samples)
+
+ # Compute PSNRs.
+ if opts.num_gpus > 1:
+ torch.distributed.all_reduce(sums)
+ sums = sums.cpu()
+ mses = sums[0::2] / sums[1::2]
+ psnrs = np.log10(2) * 20 - mses.log10() * 10
+ psnrs = tuple(psnrs.numpy())
+ return psnrs[0] if len(psnrs) == 1 else psnrs
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/frechet_inception_distance.py b/modules/eg3ds/metrics/frechet_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..e682de6162066e255b04c0db2f1cc8860c96de7c
--- /dev/null
+++ b/modules/eg3ds/metrics/frechet_inception_distance.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Frechet Inception Distance (FID) from the paper
+"GANs trained by a two time-scale update rule converge to a local Nash
+equilibrium". Matches the original implementation by Heusel et al. at
+https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
+
+import numpy as np
+import scipy.linalg
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_fid(opts, max_real, num_gen):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ # detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_url = 'file:///home/tiger/nfs/myenv/cache/useful_ckpts/inception-2015-12-05.pkl'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
+
+ mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ m = np.square(mu_gen - mu_real).sum()
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
+ return float(fid)
+
+#----------------------------------------------------------------------------
+
diff --git a/modules/eg3ds/metrics/inception_score.py b/modules/eg3ds/metrics/inception_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8887595d5d563d391a9f95f193081e70d11caba
--- /dev/null
+++ b/modules/eg3ds/metrics/inception_score.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Inception Score (IS) from the paper "Improved techniques for training
+GANs". Matches the original implementation by Salimans et al. at
+https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_is(opts, num_gen, num_splits):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ # detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_url = 'file:///home/tiger/nfs/myenv/cache/useful_ckpts/inception-2015-12-05.pkl'
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
+
+ gen_probs = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan'), float('nan')
+
+ scores = []
+ for i in range(num_splits):
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
+ kl = np.mean(np.sum(kl, axis=1))
+ scores.append(np.exp(kl))
+ return float(np.mean(scores)), float(np.std(scores))
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/kernel_inception_distance.py b/modules/eg3ds/metrics/kernel_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a7735f387fb639135a0dd9a63be6b24c9bb3ade
--- /dev/null
+++ b/modules/eg3ds/metrics/kernel_inception_distance.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
+GANs". Matches the original implementation by Binkowski et al. at
+https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ # detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_url = 'file:///home/tiger/nfs/myenv/cache/useful_ckpts/inception-2015-12-05.pkl'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ n = real_features.shape[1]
+ m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
+ t = 0
+ for _subset_idx in range(num_subsets):
+ x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
+ y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
+ b = (x @ y.T / n + 1) ** 3
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
+ kid = t / num_subsets / m
+ return float(kid)
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/metric_main.py b/modules/eg3ds/metrics/metric_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..77eadbef168888cd740abb2e638ee111ef15c559
--- /dev/null
+++ b/modules/eg3ds/metrics/metric_main.py
@@ -0,0 +1,155 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Main API for computing and reporting quality metrics."""
+
+import os
+import time
+import json
+import torch
+import modules.eg3ds.dnnlib as dnnlib
+
+from . import metric_utils
+from . import frechet_inception_distance
+from . import kernel_inception_distance
+from . import precision_recall
+from . import perceptual_path_length
+from . import inception_score
+from . import equivariance
+
+#----------------------------------------------------------------------------
+
+_metric_dict = dict() # name => fn
+
+def register_metric(fn):
+ assert callable(fn)
+ _metric_dict[fn.__name__] = fn
+ return fn
+
+def is_valid_metric(metric):
+ return metric in _metric_dict
+
+def list_valid_metrics():
+ return list(_metric_dict.keys())
+
+#----------------------------------------------------------------------------
+
+def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
+ assert is_valid_metric(metric)
+ opts = metric_utils.MetricOptions(**kwargs)
+
+ # Calculate.
+ start_time = time.time()
+ results = _metric_dict[metric](opts)
+ total_time = time.time() - start_time
+
+ # Broadcast results.
+ for key, value in list(results.items()):
+ if opts.num_gpus > 1:
+ value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
+ torch.distributed.broadcast(tensor=value, src=0)
+ value = float(value.cpu())
+ results[key] = value
+
+ # Decorate with metadata.
+ return dnnlib.EasyDict(
+ results = dnnlib.EasyDict(results),
+ metric = metric,
+ total_time = total_time,
+ total_time_str = dnnlib.util.format_time(total_time),
+ num_gpus = opts.num_gpus,
+ )
+
+#----------------------------------------------------------------------------
+
+def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
+ metric = result_dict['metric']
+ assert is_valid_metric(metric)
+ if run_dir is not None and snapshot_pkl is not None:
+ snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
+
+ jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
+ print(jsonl_line)
+ if run_dir is not None and os.path.isdir(run_dir):
+ with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
+ f.write(jsonl_line + '\n')
+
+#----------------------------------------------------------------------------
+# Recommended metrics.
+
+@register_metric
+def fid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
+ return dict(fid50k_full=fid)
+
+@register_metric
+def kid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k_full=kid)
+
+@register_metric
+def pr50k3_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
+
+@register_metric
+def ppl2_wend(opts):
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
+ return dict(ppl2_wend=ppl)
+
+@register_metric
+def eqt50k_int(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
+ return dict(eqt50k_int=psnr)
+
+@register_metric
+def eqt50k_frac(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
+ return dict(eqt50k_frac=psnr)
+
+@register_metric
+def eqr50k(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
+ return dict(eqr50k=psnr)
+
+#----------------------------------------------------------------------------
+# Legacy metrics.
+
+@register_metric
+def fid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
+ return dict(fid50k=fid)
+
+@register_metric
+def kid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k=kid)
+
+@register_metric
+def pr50k3(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_precision=precision, pr50k3_recall=recall)
+
+@register_metric
+def is50k(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
+ return dict(is50k_mean=mean, is50k_std=std)
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/metric_utils.py b/modules/eg3ds/metrics/metric_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..756169b281ff0cf72bbacb879bafccc2721b5d42
--- /dev/null
+++ b/modules/eg3ds/metrics/metric_utils.py
@@ -0,0 +1,324 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Miscellaneous utilities used internally by the quality metrics."""
+
+import os
+import sys
+sys.path.append("/home/tiger/projects/GeneFace_private/modules/eg3ds")
+
+import time
+import hashlib
+import pickle
+import copy
+import uuid
+import numpy as np
+import torch
+import modules.eg3ds.dnnlib as dnnlib
+
+from tasks.eg3ds.dataset_utils.kv_eg3d_ffhq_dataset import KV_FFHQ_EG3D_Dataset
+from utils.commons.hparams import hparams
+#----------------------------------------------------------------------------
+
+def chunk(iterable, chunk_size):
+ final_ret = []
+ cnt = 0
+ ret = []
+ for record in iterable:
+ if cnt == 0:
+ ret = []
+ ret.append(record)
+ cnt += 1
+ if len(ret) == chunk_size:
+ final_ret.append(ret)
+ ret = []
+ if len(final_ret[-1]) != chunk_size:
+ final_ret.append(ret)
+ return final_ret
+
+class MetricOptions:
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
+ assert 0 <= rank < num_gpus
+ self.G = G
+ self.G_kwargs = dnnlib.EasyDict(G_kwargs)
+ self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
+ self.num_gpus = num_gpus
+ self.rank = rank
+ self.device = device if device is not None else torch.device('cuda', rank)
+ self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
+ self.cache = cache
+
+#----------------------------------------------------------------------------
+
+_feature_detector_cache = dict()
+
+def get_feature_detector_name(url):
+ return os.path.splitext(url.split('/')[-1])[0]
+
+def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
+ assert 0 <= rank < num_gpus
+ key = (url, device)
+ if key not in _feature_detector_cache:
+ is_leader = (rank == 0)
+ if not is_leader and num_gpus > 1:
+ torch.distributed.barrier() # leader goes first
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
+
+ _feature_detector_cache[key] = pickle.load(f).to(device)
+ if is_leader and num_gpus > 1:
+ torch.distributed.barrier() # others follow
+ return _feature_detector_cache[key]
+
+#----------------------------------------------------------------------------
+
+def iterate_random_labels(opts, batch_size):
+ if opts.G.c_dim == 0:
+ c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
+ while True:
+ yield c
+ else:
+ # dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ if hparams['ds_name'] in ['FFHQ']:
+ dataset = KV_FFHQ_EG3D_Dataset('train', shuffle=False)
+ else:
+ raise NotImplementedError()
+ while True:
+ # c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
+ # c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
+ index = np.random.randint(len(dataset), size=(batch_size))
+ samples = dataset[index]
+ cameras = [s['real_camera'] for s in samples]
+ c = torch.stack(cameras).pin_memory().to(opts.device)
+ yield c
+
+#----------------------------------------------------------------------------
+
+class FeatureStats:
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
+ self.capture_all = capture_all
+ self.capture_mean_cov = capture_mean_cov
+ self.max_items = max_items
+ self.num_items = 0
+ self.num_features = None
+ self.all_features = None
+ self.raw_mean = None
+ self.raw_cov = None
+
+ def set_num_features(self, num_features):
+ if self.num_features is not None:
+ assert num_features == self.num_features
+ else:
+ self.num_features = num_features
+ self.all_features = []
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
+
+ def is_full(self):
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
+
+ def append(self, x):
+ x = np.asarray(x, dtype=np.float32)
+ assert x.ndim == 2
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
+ if self.num_items >= self.max_items:
+ return
+ x = x[:self.max_items - self.num_items]
+
+ self.set_num_features(x.shape[1])
+ self.num_items += x.shape[0]
+ if self.capture_all:
+ self.all_features.append(x)
+ if self.capture_mean_cov:
+ x64 = x.astype(np.float64)
+ self.raw_mean += x64.sum(axis=0)
+ self.raw_cov += x64.T @ x64
+
+ def append_torch(self, x, num_gpus=1, rank=0):
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
+ assert 0 <= rank < num_gpus
+ if num_gpus > 1:
+ ys = []
+ for src in range(num_gpus):
+ y = x.clone()
+ torch.distributed.broadcast(y, src=src)
+ ys.append(y)
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
+ self.append(x.cpu().numpy())
+
+ def get_all(self):
+ assert self.capture_all
+ return np.concatenate(self.all_features, axis=0)
+
+ def get_all_torch(self):
+ return torch.from_numpy(self.get_all())
+
+ def get_mean_cov(self):
+ assert self.capture_mean_cov
+ mean = self.raw_mean / self.num_items
+ cov = self.raw_cov / self.num_items
+ cov = cov - np.outer(mean, mean)
+ return mean, cov
+
+ def save(self, pkl_file):
+ with open(pkl_file, 'wb') as f:
+ pickle.dump(self.__dict__, f)
+
+ @staticmethod
+ def load(pkl_file):
+ with open(pkl_file, 'rb') as f:
+ s = dnnlib.EasyDict(pickle.load(f))
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
+ obj.__dict__.update(s)
+ return obj
+
+#----------------------------------------------------------------------------
+
+class ProgressMonitor:
+ def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
+ self.tag = tag
+ self.num_items = num_items
+ self.verbose = verbose
+ self.flush_interval = flush_interval
+ self.progress_fn = progress_fn
+ self.pfn_lo = pfn_lo
+ self.pfn_hi = pfn_hi
+ self.pfn_total = pfn_total
+ self.start_time = time.time()
+ self.batch_time = self.start_time
+ self.batch_items = 0
+ if self.progress_fn is not None:
+ self.progress_fn(self.pfn_lo, self.pfn_total)
+
+ def update(self, cur_items):
+ assert (self.num_items is None) or (cur_items <= self.num_items)
+ if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
+ return
+ cur_time = time.time()
+ total_time = cur_time - self.start_time
+ time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
+ if (self.verbose) and (self.tag is not None):
+ print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
+ self.batch_time = cur_time
+ self.batch_items = cur_items
+
+ if (self.progress_fn is not None) and (self.num_items is not None):
+ self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
+
+ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
+ return ProgressMonitor(
+ tag = tag,
+ num_items = num_items,
+ flush_interval = flush_interval,
+ verbose = self.verbose,
+ progress_fn = self.progress_fn,
+ pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
+ pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
+ pfn_total = self.pfn_total,
+ )
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
+ # dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ if hparams['ds_name'] in ['FFHQ']:
+ dataset = KV_FFHQ_EG3D_Dataset('train', shuffle=False)
+ else:
+ raise NotImplementedError()
+
+ if data_loader_kwargs is None:
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+
+ # Try to lookup from cache.
+ cache_file = None
+ if opts.cache:
+ # Choose cache file name.
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
+ md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
+ ds_name = hparams['ds_name'] + dataset.prefix
+ cache_tag = f'{ds_name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
+ cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
+
+ # Check if the file exists (all processes must agree).
+ flag = os.path.isfile(cache_file) if opts.rank == 0 else False
+ if opts.num_gpus > 1:
+ flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
+ torch.distributed.broadcast(tensor=flag, src=0)
+ flag = (float(flag.cpu()) != 0)
+
+ # Load.
+ if flag:
+ return FeatureStats.load(cache_file)
+
+ # Initialize.
+ num_items = len(dataset)
+ if max_items is not None:
+ num_items = min(num_items, max_items)
+ stats = FeatureStats(max_items=num_items, **stats_kwargs)
+ progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
+ item_subset = chunk(item_subset, chunk_size=batch_size)
+ for batch in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=1, collate_fn=dataset.collater, **data_loader_kwargs):
+ images = batch['real_imgs']
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+
+ if images.dtype != torch.uint8:
+ images = (images * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+
+ features = detector(images.to(opts.device), **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+
+ # Save to cache.
+ if cache_file is not None and opts.rank == 0:
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ temp_file = cache_file + '.' + uuid.uuid4().hex
+ stats.save(temp_file)
+ os.replace(temp_file, cache_file) # atomic
+ return stats
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs):
+ if batch_gen is None:
+ batch_gen = min(batch_size, 4)
+ assert batch_size % batch_gen == 0
+
+ # Setup generator and labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)
+
+ # Initialize.
+ stats = FeatureStats(**stats_kwargs)
+ assert stats.max_items is not None
+ progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ while not stats.is_full():
+ images = []
+ for _i in range(batch_size // batch_gen):
+ z = torch.randn([batch_gen, G.z_dim], device=opts.device)
+ img = G(z=z, camera=next(c_iter))['image']
+ # img = G(z=z, c=next(c_iter), **opts.G_kwargs)['image']
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ images.append(img)
+ images = torch.cat(images)
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+ features = detector(images, **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+ return stats
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/perceptual_path_length.py b/modules/eg3ds/metrics/perceptual_path_length.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e58dac3317733e2ace6d64ee1f97cafa0a38225
--- /dev/null
+++ b/modules/eg3ds/metrics/perceptual_path_length.py
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
+Architecture for Generative Adversarial Networks". Matches the original
+implementation by Karras et al. at
+https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
+
+import copy
+import numpy as np
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+# Spherical interpolation of a batch of vectors.
+def slerp(a, b, t):
+ a = a / a.norm(dim=-1, keepdim=True)
+ b = b / b.norm(dim=-1, keepdim=True)
+ d = (a * b).sum(dim=-1, keepdim=True)
+ p = t * torch.acos(d)
+ c = b - d * a
+ c = c / c.norm(dim=-1, keepdim=True)
+ d = a * torch.cos(p) + c * torch.sin(p)
+ d = d / d.norm(dim=-1, keepdim=True)
+ return d
+
+#----------------------------------------------------------------------------
+
+class PPLSampler(torch.nn.Module):
+ def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
+ assert space in ['z', 'w']
+ assert sampling in ['full', 'end']
+ super().__init__()
+ self.G = copy.deepcopy(G)
+ self.G_kwargs = G_kwargs
+ self.epsilon = epsilon
+ self.space = space
+ self.sampling = sampling
+ self.crop = crop
+ self.vgg16 = copy.deepcopy(vgg16)
+
+ def forward(self, c):
+ # Generate random latents and interpolation t-values.
+ t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
+ z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
+
+ # Interpolate in W or Z.
+ if self.space == 'w':
+ w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
+ wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
+ wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
+ else: # space == 'z'
+ zt0 = slerp(z0, z1, t.unsqueeze(1))
+ zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
+ wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
+
+ # Randomize noise buffers.
+ for name, buf in self.G.named_buffers():
+ if name.endswith('.noise_const'):
+ buf.copy_(torch.randn_like(buf))
+
+ # Generate images.
+ img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
+
+ # Center crop.
+ if self.crop:
+ assert img.shape[2] == img.shape[3]
+ c = img.shape[2] // 8
+ img = img[:, :, c*3 : c*7, c*2 : c*6]
+
+ # Downsample to 256x256.
+ factor = self.G.img_resolution // 256
+ if factor > 1:
+ img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
+
+ # Scale dynamic range from [-1,1] to [0,255].
+ img = (img + 1) * (255 / 2)
+ if self.G.img_channels == 1:
+ img = img.repeat([1, 3, 1, 1])
+
+ # Evaluate differential LPIPS.
+ lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
+ dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
+ return dist
+
+#----------------------------------------------------------------------------
+
+def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
+ vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+ vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
+
+ # Setup sampler and labels.
+ sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
+ sampler.eval().requires_grad_(False).to(opts.device)
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+ # Sampling loop.
+ dist = []
+ progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+ progress.update(batch_start)
+ x = sampler(next(c_iter))
+ for src in range(opts.num_gpus):
+ y = x.clone()
+ if opts.num_gpus > 1:
+ torch.distributed.broadcast(y, src=src)
+ dist.append(y)
+ progress.update(num_samples)
+
+ # Compute PPL.
+ if opts.rank != 0:
+ return float('nan')
+ dist = torch.cat(dist)[:num_samples].cpu().numpy()
+ lo = np.percentile(dist, 1, interpolation='lower')
+ hi = np.percentile(dist, 99, interpolation='higher')
+ ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
+ return float(ppl)
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/precision_recall.py b/modules/eg3ds/metrics/precision_recall.py
new file mode 100644
index 0000000000000000000000000000000000000000..6043717d59c53c34d76e35600a58f91e77659e0c
--- /dev/null
+++ b/modules/eg3ds/metrics/precision_recall.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Precision/Recall (PR) from the paper "Improved Precision and Recall
+Metric for Assessing Generative Models". Matches the original implementation
+by Kynkaanniemi et al. at
+https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
+
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
+ assert 0 <= rank < num_gpus
+ num_cols = col_features.shape[0]
+ num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
+ col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
+ dist_batches = []
+ for col_batch in col_batches[rank :: num_gpus]:
+ dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
+ for src in range(num_gpus):
+ dist_broadcast = dist_batch.clone()
+ if num_gpus > 1:
+ torch.distributed.broadcast(dist_broadcast, src=src)
+ dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
+ return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
+
+#----------------------------------------------------------------------------
+
+def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
+ # detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+ detector_url = 'file:///home/tiger/nfs/myenv/cache/useful_ckpts/vgg16.pkl'
+ detector_kwargs = dict(return_features=True)
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
+
+ results = dict()
+ for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
+ kth = []
+ for manifold_batch in manifold.split(row_batch_size):
+ dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
+ kth = torch.cat(kth) if opts.rank == 0 else None
+ pred = []
+ for probes_batch in probes.split(row_batch_size):
+ dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
+ results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
+ return results['precision'], results['recall']
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/models/dual_discriminator.py b/modules/eg3ds/models/dual_discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..58d82d4148253a341cf3bccf7bd056a39be00e22
--- /dev/null
+++ b/modules/eg3ds/models/dual_discriminator.py
@@ -0,0 +1,374 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Discriminator architectures from the paper
+"Efficient Geometry-aware 3D Generative Adversarial Networks"."""
+
+import numpy as np
+import torch
+import torch.nn as nn
+#
+
+from modules.eg3ds.torch_utils.ops import upfirdn2d
+from modules.eg3ds.models.networks_stylegan2 import DiscriminatorBlock, MappingNetwork, DiscriminatorEpilogue
+from einops import rearrange
+from utils.commons.hparams import hparams
+
+
+class SingleDiscriminator(torch.nn.Module):
+ def __init__(self,
+ img_resolution, # Input resolution.
+ img_channels =3, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ sr_upsample_factor = 1, # Ignored for SingleDiscriminator
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ self.camera_dim = 25
+ if hparams['disc_cond_mode'] == 'idexp_lm3d_normalized':
+ self.cond_dim = 204
+ else:
+ self.cond_dim = 0
+ c_dim = self.camera_dim
+ self.c_dim = c_dim
+
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+
+ def forward(self, img, camera, cond=None, update_emas=False, **block_kwargs):
+ img = img['image']
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ c = camera
+ if self.cond_dim > 0:
+ cond_feat = self.cond_encoder(cond)
+ c = torch.cat([c, cond_feat], dim=-1) # [b, 25+8]
+
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
+ is_bcthw_flag = True if image_orig_tensor.ndim == 5 else False
+ if is_bcthw_flag: # [B, c, T, H, W]
+ n,c,t,h,w = image_orig_tensor.shape
+ image_orig_tensor = rearrange(image_orig_tensor, "n c t h w -> (n t) c h w")
+
+ if filter_mode == 'antialiased':
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
+ elif filter_mode == 'classic':
+ ada_filtered_64 = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
+ ada_filtered_64 = torch.nn.functional.interpolate(ada_filtered_64, size=(size * 2 + 2, size * 2 + 2), mode='bilinear', align_corners=False)
+ ada_filtered_64 = upfirdn2d.downsample2d(ada_filtered_64, f, down=2, flip_filter=True, padding=-1)
+ elif filter_mode == 'none':
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
+ elif type(filter_mode) == float:
+ assert 0 < filter_mode < 1
+
+ filtered = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
+ aliased = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=False)
+ ada_filtered_64 = (1 - filter_mode) * aliased + (filter_mode) * filtered
+ if is_bcthw_flag: # [B, c, T, H, W]
+ ada_filtered_64 = rearrange(ada_filtered_64, "(n t) c h w -> n c t h w", n=n,t=t)
+
+ return ada_filtered_64
+
+#----------------------------------------------------------------------------
+
+class DualDiscriminator(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ channel_base = hparams['base_channel']
+ channel_max = hparams['max_channel']
+ conv_clamp = 256
+ cmap_dim = None
+ block_kwargs = {'freeze_layers': 0}
+ mapping_kwargs = {}
+ epilogue_kwargs = {'mbstd_group_size': hparams['group_size_for_mini_batch_std']}
+ architecture = 'resnet' # Architecture: 'orig', 'skip', 'resnet'.
+
+ img_channels = 3
+ img_channels *= 2
+
+ self.camera_dim = 25
+ c_dim = self.camera_dim
+
+ self.img_resolution = hparams['final_resolution']
+ self.img_resolution_log2 = int(np.log2(self.img_resolution))
+ self.img_channels = 3
+
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ self.num_fp16_res = hparams['num_fp16_layers_in_discriminator']
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - self.num_fp16_res), 8)
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < self.img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ # use_fp16 = True
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ if hparams.get("disc_cond_mode", 'none') != 'none':
+ """
+ For discriminator, embed cond with mapping network works well.
+ """
+ self.cond_dim = 204
+ self.mapping = MappingNetwork(z_dim=self.cond_dim, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ def forward(self, img, camera, cond=None, update_emas=False, feature_maps=None, **block_kwargs):
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
+ img = torch.cat([img['image'], image_raw], 1)
+
+ # add by yerfor
+ img = torch.clamp(img, min=-1, max=1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+ if feature_maps is not None:
+ feature_maps.append(x)
+ cmap = None
+
+ c = camera.clone() # prevent inplace modification in sample!
+ if hparams['disc_c_noise'] > 0:
+ if len(c) > 1:
+ c_std = c.std(0)
+ else:
+ # c_std = 1
+ c_std = torch.tensor([0.0664, 0.0295, 0.2720, 0.6971, 0.0279, 0.0178, 0.1280, 0.3284, 0.2721,
+ 0.1274, 0.0679, 0.1642, 0.0000, 0.0000, 0.0000, 0.0000, 0.0079, 0.0000,
+ 0.0000, 0.0000, 0.0079, 0.0000, 0.0000, 0.0000, 0.0000]).to(c.device)
+ c += torch.randn_like(c) * c_std * hparams['disc_c_noise']
+
+ # x: [B, 512, 4, 4], img: None, cmap: [B, 512]
+ if hparams.get("disc_cond_mode", 'none') != 'none':
+ cmap = self.mapping(cond, c)
+ else:
+ cmap = self.mapping(None, c)
+
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+class DummyDualDiscriminator(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ img_channels *= 2
+
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ self.raw_fade = 1
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ self.raw_fade = max(0, self.raw_fade - 1/(500000/32))
+
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) * self.raw_fade
+ img = torch.cat([img['image'], image_raw], 1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+
+# Tri-discriminator: upsampled image, super-resolved image, and segmentation mask
+# V2: first concatenate imgs and seg mask, using only one conv block
+class MaskDualDiscriminatorV2(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ seg_resolution = 128, # Input resolution.
+ seg_channels = 1, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ disc_c_noise = 0, # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ img_channels = img_channels * 2 + seg_channels
+
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.seg_resolution = seg_resolution
+ self.seg_channels = seg_channels
+
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+ self.disc_c_noise = disc_c_noise
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
+ seg = filtered_resizing(img['image_mask'], size=img['image'].shape[-1], f=self.resample_filter)
+ seg = 2 * seg - 1 # normalize to [-1,1]
+ img = torch.cat([img['image'], image_raw, seg], 1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ if self.disc_c_noise > 0: c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return ' '.join([
+ f'c_dim={self.c_dim:d},',
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
+ f'seg_resolution={self.seg_resolution:d}, seg_channels={self.seg_channels:d}'])
\ No newline at end of file
diff --git a/modules/eg3ds/models/dual_discriminator_cond.py b/modules/eg3ds/models/dual_discriminator_cond.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d6b37470054d002607f05fb764988d160272c80
--- /dev/null
+++ b/modules/eg3ds/models/dual_discriminator_cond.py
@@ -0,0 +1,279 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Discriminator architectures from the paper
+"Efficient Geometry-aware 3D Generative Adversarial Networks"."""
+
+import numpy as np
+import torch
+import torch.nn as nn
+from modules.eg3ds.torch_utils.ops import upfirdn2d
+from modules.eg3ds.models.networks_stylegan2 import DiscriminatorBlock, MappingNetwork, DiscriminatorEpilogue
+from modules.eg3ds.models.cond_encoder import LM3D_Win_Encoder
+
+from utils.commons.hparams import hparams
+
+
+class SingleDiscriminator(torch.nn.Module):
+ def __init__(self,
+ img_resolution, # Input resolution.
+ img_channels =3, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ sr_upsample_factor = 1, # Ignored for SingleDiscriminator
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ self.camera_dim = 25
+ if hparams['cond_type'] == 'idexp_lm3d_normalized':
+ self.cond_dim = 204
+ else:
+ self.cond_dim = 0
+ c_dim = self.camera_dim
+ if self.cond_dim > 0:
+ cond_out_dim = hparams['cond_out_dim']
+ c_dim += cond_out_dim
+ self.cond_encoder = LM3D_Win_Encoder(self.cond_dim, hid_dim=hparams['cond_hid_dim'], out_dim=cond_out_dim, smo_size=hparams['smo_win_size'])
+ self.c_dim = c_dim
+
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+
+ def forward(self, img, camera, cond=None, update_emas=False, **block_kwargs):
+ img = img['image']
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ c = camera
+ if self.cond_dim > 0:
+ cond_feat = self.cond_encoder(cond)
+ c = torch.cat([c, cond_feat], dim=-1) # [b, 25+8]
+
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
+ if filter_mode == 'antialiased':
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
+ elif filter_mode == 'classic':
+ ada_filtered_64 = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
+ ada_filtered_64 = torch.nn.functional.interpolate(ada_filtered_64, size=(size * 2 + 2, size * 2 + 2), mode='bilinear', align_corners=False)
+ ada_filtered_64 = upfirdn2d.downsample2d(ada_filtered_64, f, down=2, flip_filter=True, padding=-1)
+ elif filter_mode == 'none':
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
+ elif type(filter_mode) == float:
+ assert 0 < filter_mode < 1
+
+ filtered = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
+ aliased = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=False)
+ ada_filtered_64 = (1 - filter_mode) * aliased + (filter_mode) * filtered
+
+ return ada_filtered_64
+
+#----------------------------------------------------------------------------
+
+class DualDiscriminator(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ channel_base = hparams['base_channel']
+ channel_max = hparams['max_channel']
+ conv_clamp = 256
+ cmap_dim = None
+ disc_c_noise = 0.
+ block_kwargs = {'freeze_layers': 0}
+ mapping_kwargs = {}
+ epilogue_kwargs = {'mbstd_group_size': 4}
+ architecture = 'resnet' # Architecture: 'orig', 'skip', 'resnet'.
+
+ img_channels = 3
+ img_channels *= 2
+
+ self.camera_dim = 25
+ if hparams['cond_type'] == 'idexp_lm3d_normalized':
+ self.cond_dim = 204
+ else:
+ self.cond_dim = 0
+ c_dim = self.camera_dim
+
+ if self.cond_dim > 0:
+ cond_out_dim = hparams['cond_out_dim']
+ c_dim += cond_out_dim
+ self.cond_encoder = LM3D_Win_Encoder(self.cond_dim, hid_dim=hparams['cond_hid_dim'], out_dim=cond_out_dim, smo_size=hparams['smo_win_size'])
+
+ self.img_resolution = hparams['final_resolution']
+ self.img_resolution_log2 = int(np.log2(self.img_resolution))
+ self.img_channels = 3
+
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ self.num_fp16_res = hparams['num_fp16_layers_in_discriminator']
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - self.num_fp16_res), 8)
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < self.img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+ self.disc_c_noise = disc_c_noise
+
+ def forward(self, img, camera, cond=None, update_emas=False, **block_kwargs):
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
+ img = torch.cat([img['image'], image_raw], 1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+
+ c = camera
+ if self.cond_dim > 0:
+ cond_feat = self.cond_encoder(cond)
+ c = torch.cat([c, cond_feat], dim=-1) # [b, 25+8]
+ if self.disc_c_noise > 0:
+ c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
+
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+class DummyDualDiscriminator(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ img_channels *= 2
+
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ self.raw_fade = 1
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ self.raw_fade = max(0, self.raw_fade - 1/(500000/32))
+
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) * self.raw_fade
+ img = torch.cat([img['image'], image_raw], 1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/models/networks_stylegan2.py b/modules/eg3ds/models/networks_stylegan2.py
new file mode 100644
index 0000000000000000000000000000000000000000..30d638e3f6898cec32c92c475ad7a73df12e8f9c
--- /dev/null
+++ b/modules/eg3ds/models/networks_stylegan2.py
@@ -0,0 +1,814 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Network architectures from the paper
+"Analyzing and Improving the Image Quality of StyleGAN".
+Matches the original implementation of configs E-F by Karras et al. at
+https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+import math
+from modules.eg3ds.torch_utils import misc
+from modules.eg3ds.torch_utils.ops import conv2d_resample
+from modules.eg3ds.torch_utils.ops import upfirdn2d
+from modules.eg3ds.torch_utils.ops import bias_act
+from modules.eg3ds.torch_utils.ops import fma
+
+from utils.commons.hparams import hparams
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def normalize_2nd_moment(x, dim=1, eps=1e-8):
+ return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def modulated_conv2d(
+ x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
+ weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
+ styles, # Modulation coefficients of shape [batch_size, in_channels].
+ noise = None, # Optional noise tensor to add to the output activations.
+ up = 1, # Integer upsampling factor.
+ down = 1, # Integer downsampling factor.
+ padding = 0, # Padding with respect to the upsampled image.
+ resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
+ demodulate = True, # Apply weight demodulation?
+ flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
+ fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
+):
+ batch_size = x.shape[0]
+ out_channels, in_channels, kh, kw = weight.shape
+ misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
+ misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
+ misc.assert_shape(styles, [batch_size, in_channels]) # [NI]
+
+ # Pre-normalize inputs to avoid FP16 overflow.
+ if x.dtype == torch.float16 and demodulate:
+ weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
+ styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
+
+ # Calculate per-sample weights and demodulation coefficients.
+ w = None
+ dcoefs = None
+ if demodulate or fused_modconv:
+ w = weight.unsqueeze(0) # [NOIkk]
+ w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk], 将weight乘以style
+ if demodulate:
+ dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] # [2, 512,512,3,3]==>[2, 512] 归一化
+ if demodulate and fused_modconv:
+ w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
+
+ # Execute by scaling the activations before and after the convolution.
+ if not fused_modconv:
+ x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) # 将x乘以style
+ x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) # conv2d forward
+ if demodulate and noise is not None:
+ x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) # FusedMultiplyAdd
+ elif demodulate:
+ x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
+ elif noise is not None:
+ x = x.add_(noise.to(x.dtype))
+ return x
+
+ # Execute as one fused op using grouped convolution.
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ batch_size = int(batch_size)
+ misc.assert_shape(x, [batch_size, in_channels, None, None])
+ x = x.reshape(1, -1, *x.shape[2:])
+ w = w.reshape(-1, in_channels, kh, kw)
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
+ x = x.reshape(batch_size, -1, *x.shape[2:])
+ if noise is not None:
+ x = x.add_(noise)
+ return x
+
+#----------------------------------------------------------------------------
+
+
+class FullyConnectedLayer(torch.nn.Module):
+ def __init__(self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 1, # Learning rate multiplier.
+ bias_init = 0, # Initial value for the additive bias.
+ ):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.activation = activation
+ self.bias_init = bias_init
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+ def extra_repr(self):
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
+
+#----------------------------------------------------------------------------
+
+
+class Conv2dLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ kernel_size, # Width and height of the convolution kernel.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ up = 1, # Integer upsampling factor.
+ down = 1, # Integer downsampling factor.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
+ channels_last = False, # Expect the input to have memory_format=channels_last?
+ trainable = True, # Update the weights of this layer during training?
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.activation = activation
+ self.up = up
+ self.down = down
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+ self.trainable = trainable
+
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
+ bias = torch.zeros([out_channels]) if bias else None
+ if trainable:
+ self.weight = torch.nn.Parameter(weight)
+ self.bias = torch.nn.Parameter(bias) if bias is not None else None
+ else:
+ self.register_buffer('weight', weight)
+ if bias is not None:
+ self.register_buffer('bias', bias)
+ else:
+ self.bias = None
+
+ def forward(self, x, gain=1):
+ w = self.weight * self.weight_gain
+
+ b = self.bias.to(x.dtype) if self.bias is not None else None
+ flip_weight = (self.up == 1) # slightly faster
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
+ return x
+
+ def extra_repr(self):
+ return ' '.join([
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},',
+ f'up={self.up}, down={self.down}'])
+
+#----------------------------------------------------------------------------
+
+
+class MappingNetwork(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
+ w_dim, # Intermediate latent (W) dimensionality.
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
+ num_layers = 8, # Number of mapping layers.
+ embed_features = None, # Label embedding dimensionality, None = same as w_dim.
+ layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ last_activation = None, # add by panohead, define the last activation
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
+ w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.num_ws = num_ws
+ self.num_layers = num_layers
+ self.w_avg_beta = w_avg_beta
+
+ if embed_features is None:
+ embed_features = w_dim
+ if c_dim == 0:
+ embed_features = 0
+ if layer_features is None:
+ layer_features = w_dim
+ features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
+
+ if c_dim > 0:
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
+ for idx in range(num_layers):
+ in_features = features_list[idx]
+ out_features = features_list[idx + 1]
+ if idx == num_layers - 1 and last_activation:
+ layer = FullyConnectedLayer(in_features, out_features, activation=last_activation, lr_multiplier=lr_multiplier)
+ else:
+ layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
+ setattr(self, f'fc{idx}', layer)
+
+ if num_ws is not None and w_avg_beta is not None:
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
+ # Embed, normalize, and concat inputs.
+ x = None
+ with torch.autograd.profiler.record_function('input'):
+ if self.z_dim > 0:
+ misc.assert_shape(z, [None, self.z_dim])
+ x = normalize_2nd_moment(z.to(torch.float32))
+ if self.c_dim > 0:
+ misc.assert_shape(c, [None, self.c_dim])
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
+ x = torch.cat([x, y], dim=1) if x is not None else y
+
+ # Main layers.
+ for idx in range(self.num_layers):
+ layer = getattr(self, f'fc{idx}')
+ x = layer(x)
+
+ # Update moving average of W.
+ if update_emas and self.w_avg_beta is not None:
+ with torch.autograd.profiler.record_function('update_w_avg'):
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
+
+ # Broadcast.
+ if self.num_ws is not None:
+ with torch.autograd.profiler.record_function('broadcast'):
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+
+ # Apply truncation trick.
+ if truncation_psi != 1:
+ with torch.autograd.profiler.record_function('truncate'):
+ assert self.w_avg_beta is not None
+ if self.num_ws is None or truncation_cutoff is None:
+ x = self.w_avg.lerp(x, truncation_psi)
+ else:
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) # 从w_avg出发向x前进,前进步数[0~1.]为truncation_psi
+ return x
+
+ def extra_repr(self):
+ return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this layer.
+ kernel_size = 3, # Convolution kernel size.
+ up = 1, # Integer upsampling factor.
+ use_noise = True, # Enable noise input?
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ channels_last = False, # Use channels_last format for the weights?
+ **other_args
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.up = up
+ self.use_noise = use_noise
+ self.activation = activation
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
+ if use_noise:
+ self.register_buffer('noise_const', torch.randn([resolution, resolution]))
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+
+ def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1, **kwargs):
+ assert noise_mode in ['random', 'const', 'none']
+ in_resolution = self.resolution // self.up
+ misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution])
+ styles = self.affine(w)
+
+ noise = None
+ if self.use_noise and noise_mode == 'random':
+ noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
+ if self.use_noise and noise_mode == 'const':
+ noise = self.noise_const * self.noise_strength
+
+ flip_weight = (self.up == 1) # slightly faster
+ weight = self.weight
+ x = modulated_conv2d(x=x, weight=weight, styles=styles, noise=noise, up=self.up,
+ padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
+ return x
+
+ def extra_repr(self):
+ return ' '.join([
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},',
+ f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}'])
+
+#----------------------------------------------------------------------------
+
+
+class ToRGBLayer(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.w_dim = w_dim
+ self.conv_clamp = conv_clamp
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+
+ def forward(self, x, w, fused_modconv=True):
+ styles = self.affine(w) * self.weight_gain
+ weight = self.weight
+ x = modulated_conv2d(x=x, weight=weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) # demodulate为False
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
+ return x
+
+ def extra_repr(self):
+ return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}'
+
+#----------------------------------------------------------------------------
+
+class SynthesisBlock(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this block.
+ img_channels, # Number of output color channels.
+ is_last, # Is this the last block?
+ architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.is_last = is_last
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.fused_modconv_default = fused_modconv_default
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.num_conv = 0
+ self.num_torgb = 0
+
+ if in_channels == 0:
+ self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
+
+ if in_channels != 0:
+ self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
+ resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
+ conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ if is_last or architecture == 'skip':
+ self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
+ conv_clamp=conv_clamp, channels_last=self.channels_last)
+ self.num_torgb += 1
+
+ if in_channels != 0 and architecture == 'resnet':
+ self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
+ resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
+ _ = update_emas # unused
+ misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
+ w_iter = iter(ws.unbind(dim=1))
+ if ws.device.type != 'cuda':
+ force_fp32 = True
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+ if fused_modconv is None:
+ fused_modconv = self.fused_modconv_default
+ if fused_modconv == 'inference_only':
+ fused_modconv = (not self.training)
+
+ # Input.
+ if self.in_channels == 0:
+ x = self.const.to(dtype=dtype, memory_format=memory_format)
+ x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
+ else:
+ misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # Main layers.
+ if self.in_channels == 0:
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ elif self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
+ x = y.add_(x)
+ else:
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+
+ # ToRGB.
+ if img is not None:
+ misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
+ img = upfirdn2d.upsample2d(img, self.resample_filter)
+ if self.is_last or self.architecture == 'skip':
+ y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
+ y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+ img = img.add_(y) if img is not None else y
+
+ assert x.dtype == dtype
+ assert img is None or img.dtype == torch.float32
+ return x, img
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisNetwork(torch.nn.Module):
+ def __init__(self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output image resolution.
+ img_channels, # Number of color channels.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ **block_kwargs, # Arguments for SynthesisBlock.
+ ):
+ assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
+ super().__init__()
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.num_fp16_res = num_fp16_res
+ self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ self.num_ws = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res // 2] if res > 4 else 0
+ out_channels = channels_dict[res]
+ use_fp16 = (res >= fp16_resolution)
+ is_last = (res == self.img_resolution)
+ block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
+ img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
+ self.num_ws += block.num_conv
+ if is_last:
+ self.num_ws += block.num_torgb
+ setattr(self, f'b{res}', block)
+
+ def forward(self, ws, **block_kwargs):
+ block_ws = []
+ with torch.autograd.profiler.record_function('split_ws'):
+ misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
+ ws = ws.to(torch.float32)
+ w_idx = 0
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) # [B, num_conv_and_rgb, w_dim]
+ w_idx += block.num_conv
+
+ x = img = None
+ for res, cur_ws in zip(self.block_resolutions, block_ws):
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, cur_ws, **block_kwargs)
+ return img
+
+ def extra_repr(self):
+ return ' '.join([
+ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
+ f'num_fp16_res={self.num_fp16_res:d}'])
+
+#----------------------------------------------------------------------------
+
+
+class Generator(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality.
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output resolution.
+ img_channels, # Number of output color channels.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ **synthesis_kwargs, # Arguments for SynthesisNetwork.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
+ self.num_ws = self.synthesis.num_ws
+ self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
+ if hparams.get("gen_cond_mode", 'none') == 'mapping': # comes from a attemp to inject landmark condition
+ self.cond_dim = 204
+ self.cond_mapping = MappingNetwork(z_dim=self.cond_dim, c_dim=0, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
+
+ def forward(self, z, c, cond=None, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ if hparams.get("gen_cond_mode", 'none') == 'mapping':
+ d_ws = self.cond_mapping(cond, 0, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ ws = ws * 0.5 + d_ws * 0.5
+ img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+ return img
+
+#----------------------------------------------------------------------------
+
+
+class DiscriminatorBlock(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ tmp_channels, # Number of intermediate channels.
+ out_channels, # Number of output channels.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ first_layer_idx, # Index of the first layer.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ freeze_layers = 0, # Freeze-D: Number of layers to freeze.
+ ):
+ assert in_channels in [0, tmp_channels]
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.first_layer_idx = first_layer_idx
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+
+ self.num_layers = 0
+ def trainable_gen():
+ while True:
+ layer_idx = self.first_layer_idx + self.num_layers
+ trainable = (layer_idx >= freeze_layers)
+ self.num_layers += 1
+ yield trainable
+ trainable_iter = trainable_gen()
+
+ if in_channels == 0 or architecture == 'skip':
+ self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
+ trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ if architecture == 'resnet':
+ self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
+ trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, img, force_fp32=False):
+ if (x if x is not None else img).device.type != 'cuda':
+ force_fp32 = True
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+
+ # Input.
+ if x is not None:
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # FromRGB.
+ if self.in_channels == 0 or self.architecture == 'skip':
+ misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
+ img = img.to(dtype=dtype, memory_format=memory_format)
+ y = self.fromrgb(img)
+ x = x + y if x is not None else y
+ img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
+
+ # Main layers.
+ if self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x)
+ x = self.conv1(x, gain=np.sqrt(0.5))
+ x = y.add_(x)
+ else:
+ x = self.conv0(x)
+ x = self.conv1(x)
+
+ assert x.dtype == dtype
+ return x, img
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+#----------------------------------------------------------------------------
+
+
+class MinibatchStdLayer(torch.nn.Module):
+ def __init__(self, group_size, num_channels=1):
+ super().__init__()
+ self.group_size = group_size
+ self.num_channels = num_channels
+
+ def forward(self, x):
+ N, C, H, W = x.shape
+ with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
+ G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
+ F = self.num_channels
+ c = C // F
+
+ y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
+ y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
+ y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
+ y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
+ y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels.
+ y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
+ y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
+ x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
+ return x
+
+ def extra_repr(self):
+ return f'group_size={self.group_size}, num_channels={self.num_channels:d}'
+
+#----------------------------------------------------------------------------
+
+
+class DiscriminatorEpilogue(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ mbstd_group_size = 2, # Group size for the minibatch standard deviation layer, None = entire minibatch.
+ mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.cmap_dim = cmap_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.architecture = architecture
+
+ if architecture == 'skip':
+ self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
+ self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
+ self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
+ self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
+ self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)
+
+ def forward(self, x, img, cmap, force_fp32=False):
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
+ _ = force_fp32 # unused
+ dtype = torch.float32
+ memory_format = torch.contiguous_format
+
+ # FromRGB.
+ x = x.to(dtype=dtype, memory_format=memory_format)
+ if self.architecture == 'skip':
+ misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
+ img = img.to(dtype=dtype, memory_format=memory_format)
+ x = x + self.fromrgb(img)
+
+ # Main layers.
+ if self.mbstd is not None:
+ x = self.mbstd(x)
+ x = self.conv(x)
+ x = self.fc(x.flatten(1))
+ x = self.out(x)
+
+ # Conditioning.
+ if self.cmap_dim > 0:
+ misc.assert_shape(cmap, [None, self.cmap_dim])
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
+
+ assert x.dtype == dtype
+ return x
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+#----------------------------------------------------------------------------
+
+
+class Discriminator(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
\ No newline at end of file
diff --git a/modules/eg3ds/models/networks_stylegan3.py b/modules/eg3ds/models/networks_stylegan3.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5c38853db600f4006c3f6e0045a8df1e707ee85
--- /dev/null
+++ b/modules/eg3ds/models/networks_stylegan3.py
@@ -0,0 +1,516 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Generator architecture from the paper
+"Alias-Free Generative Adversarial Networks"."""
+
+import numpy as np
+import scipy.signal
+import scipy.optimize
+import torch
+from modules.eg3ds.torch_utils import misc
+from modules.eg3ds.torch_utils.ops import conv2d_gradfix
+from modules.eg3ds.torch_utils.ops import filtered_lrelu
+from modules.eg3ds.torch_utils.ops import bias_act
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def modulated_conv2d(
+ x, # Input tensor: [batch_size, in_channels, in_height, in_width]
+ w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width]
+ s, # Style tensor: [batch_size, in_channels]
+ demodulate = True, # Apply weight demodulation?
+ padding = 0, # Padding: int or [padH, padW]
+ input_gain = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels]
+):
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ batch_size = int(x.shape[0])
+ out_channels, in_channels, kh, kw = w.shape
+ misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk]
+ misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
+ misc.assert_shape(s, [batch_size, in_channels]) # [NI]
+
+ # Pre-normalize inputs.
+ if demodulate:
+ w = w * w.square().mean([1,2,3], keepdim=True).rsqrt()
+ s = s * s.square().mean().rsqrt()
+
+ # Modulate weights.
+ w = w.unsqueeze(0) # [NOIkk]
+ w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
+
+ # Demodulate weights.
+ if demodulate:
+ dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
+ w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk]
+
+ # Apply input scaling.
+ if input_gain is not None:
+ input_gain = input_gain.expand(batch_size, in_channels) # [NI]
+ w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
+
+ # Execute as one fused op using grouped convolution.
+ x = x.reshape(1, -1, *x.shape[2:])
+ w = w.reshape(-1, in_channels, kh, kw)
+ x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size)
+ x = x.reshape(batch_size, -1, *x.shape[2:])
+ return x
+
+#----------------------------------------------------------------------------
+
+
+class FullyConnectedLayer(torch.nn.Module):
+ def __init__(self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ bias = True, # Apply additive bias before the activation function?
+ lr_multiplier = 1, # Learning rate multiplier.
+ weight_init = 1, # Initial standard deviation of the weight tensor.
+ bias_init = 0, # Initial value of the additive bias.
+ ):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.activation = activation
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier))
+ bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
+ self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+ def extra_repr(self):
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
+
+#----------------------------------------------------------------------------
+
+
+class MappingNetwork(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality, 0 = no labels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ num_ws, # Number of intermediate latents to output.
+ num_layers = 2, # Number of mapping layers.
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
+ w_avg_beta = 0.998, # Decay for tracking the moving average of W during training.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.num_ws = num_ws
+ self.num_layers = num_layers
+ self.w_avg_beta = w_avg_beta
+
+ # Construct layers.
+ self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None
+ features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers
+ for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
+ layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
+ setattr(self, f'fc{idx}', layer)
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
+ misc.assert_shape(z, [None, self.z_dim])
+ if truncation_cutoff is None:
+ truncation_cutoff = self.num_ws
+
+ # Embed, normalize, and concatenate inputs.
+ x = z.to(torch.float32)
+ x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
+ if self.c_dim > 0:
+ misc.assert_shape(c, [None, self.c_dim])
+ y = self.embed(c.to(torch.float32))
+ y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
+ x = torch.cat([x, y], dim=1) if x is not None else y
+
+ # Execute layers.
+ for idx in range(self.num_layers):
+ x = getattr(self, f'fc{idx}')(x)
+
+ # Update moving average of W.
+ if update_emas:
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
+
+ # Broadcast and apply truncation.
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+ if truncation_psi != 1:
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
+ return x
+
+ def extra_repr(self):
+ return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisInput(torch.nn.Module):
+ def __init__(self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ channels, # Number of output channels.
+ size, # Output spatial size: int or [width, height].
+ sampling_rate, # Output sampling rate.
+ bandwidth, # Output bandwidth.
+ ):
+ super().__init__()
+ self.w_dim = w_dim
+ self.channels = channels
+ self.size = np.broadcast_to(np.asarray(size), [2])
+ self.sampling_rate = sampling_rate
+ self.bandwidth = bandwidth
+
+ # Draw random frequencies from uniform 2D disc.
+ freqs = torch.randn([self.channels, 2])
+ radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
+ freqs /= radii * radii.square().exp().pow(0.25)
+ freqs *= bandwidth
+ phases = torch.rand([self.channels]) - 0.5
+
+ # Setup parameters and buffers.
+ self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels]))
+ self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0])
+ self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image.
+ self.register_buffer('freqs', freqs)
+ self.register_buffer('phases', phases)
+
+ def forward(self, w):
+ # Introduce batch dimension.
+ transforms = self.transform.unsqueeze(0) # [batch, row, col]
+ freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
+ phases = self.phases.unsqueeze(0) # [batch, channel]
+
+ # Apply learned transformation.
+ t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
+ t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
+ m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
+ m_r[:, 0, 0] = t[:, 0] # r'_c
+ m_r[:, 0, 1] = -t[:, 1] # r'_s
+ m_r[:, 1, 0] = t[:, 1] # r'_s
+ m_r[:, 1, 1] = t[:, 0] # r'_c
+ m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
+ m_t[:, 0, 2] = -t[:, 2] # t'_x
+ m_t[:, 1, 2] = -t[:, 3] # t'_y
+ transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.
+
+ # Transform frequencies.
+ phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
+ freqs = freqs @ transforms[:, :2, :2]
+
+ # Dampen out-of-band frequencies that may occur due to the user-specified transform.
+ amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)
+
+ # Construct sampling grid.
+ theta = torch.eye(2, 3, device=w.device)
+ theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
+ theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
+ grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)
+
+ # Compute Fourier features.
+ x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
+ x = x + phases.unsqueeze(1).unsqueeze(2)
+ x = torch.sin(x * (np.pi * 2))
+ x = x * amplitudes.unsqueeze(1).unsqueeze(2)
+
+ # Apply trainable mapping.
+ weight = self.weight / np.sqrt(self.channels)
+ x = x @ weight.t()
+
+ # Ensure correct shape.
+ x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
+ misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])])
+ return x
+
+ def extra_repr(self):
+ return '\n'.join([
+ f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},',
+ f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}'])
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisLayer(torch.nn.Module):
+ def __init__(self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ is_torgb, # Is this the final ToRGB layer?
+ is_critically_sampled, # Does this layer use critical sampling?
+ use_fp16, # Does this layer use FP16?
+
+ # Input & output specifications.
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ in_size, # Input spatial size: int or [width, height].
+ out_size, # Output spatial size: int or [width, height].
+ in_sampling_rate, # Input sampling rate (s).
+ out_sampling_rate, # Output sampling rate (s).
+ in_cutoff, # Input cutoff frequency (f_c).
+ out_cutoff, # Output cutoff frequency (f_c).
+ in_half_width, # Input transition band half-width (f_h).
+ out_half_width, # Output Transition band half-width (f_h).
+
+ # Hyperparameters.
+ conv_kernel = 3, # Convolution kernel size. Ignored for final the ToRGB layer.
+ filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling.
+ lrelu_upsampling = 2, # Relative sampling rate for leaky ReLU. Ignored for final the ToRGB layer.
+ use_radial_filters = False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
+ conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping.
+ magnitude_ema_beta = 0.999, # Decay rate for the moving average of input magnitudes.
+ ):
+ super().__init__()
+ self.w_dim = w_dim
+ self.is_torgb = is_torgb
+ self.is_critically_sampled = is_critically_sampled
+ self.use_fp16 = use_fp16
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.in_size = np.broadcast_to(np.asarray(in_size), [2])
+ self.out_size = np.broadcast_to(np.asarray(out_size), [2])
+ self.in_sampling_rate = in_sampling_rate
+ self.out_sampling_rate = out_sampling_rate
+ self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)
+ self.in_cutoff = in_cutoff
+ self.out_cutoff = out_cutoff
+ self.in_half_width = in_half_width
+ self.out_half_width = out_half_width
+ self.conv_kernel = 1 if is_torgb else conv_kernel
+ self.conv_clamp = conv_clamp
+ self.magnitude_ema_beta = magnitude_ema_beta
+
+ # Setup parameters and buffers.
+ self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1)
+ self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel]))
+ self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
+ self.register_buffer('magnitude_ema', torch.ones([]))
+
+ # Design upsampling filter.
+ self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
+ assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
+ self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1
+ self.register_buffer('up_filter', self.design_lowpass_filter(
+ numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate))
+
+ # Design downsampling filter.
+ self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
+ assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate
+ self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1
+ self.down_radial = use_radial_filters and not self.is_critically_sampled
+ self.register_buffer('down_filter', self.design_lowpass_filter(
+ numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial))
+
+ # Compute padding.
+ pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling.
+ pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling.
+ pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters.
+ pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3).
+ pad_hi = pad_total - pad_lo
+ self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]
+
+ def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False):
+ assert noise_mode in ['random', 'const', 'none'] # unused
+ misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])])
+ misc.assert_shape(w, [x.shape[0], self.w_dim])
+
+ # Track input magnitude.
+ if update_emas:
+ with torch.autograd.profiler.record_function('update_magnitude_ema'):
+ magnitude_cur = x.detach().to(torch.float32).square().mean()
+ self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta))
+ input_gain = self.magnitude_ema.rsqrt()
+
+ # Execute affine layer.
+ styles = self.affine(w)
+ if self.is_torgb:
+ weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2))
+ styles = styles * weight_gain
+
+ # Execute modulated conv2d.
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
+ x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles,
+ padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain)
+
+ # Execute bias, filtered leaky ReLU, and clamping.
+ gain = 1 if self.is_torgb else np.sqrt(2)
+ slope = 1 if self.is_torgb else 0.2
+ x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype),
+ up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp)
+
+ # Ensure correct shape and dtype.
+ misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])])
+ assert x.dtype == dtype
+ return x
+
+ @staticmethod
+ def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
+ assert numtaps >= 1
+
+ # Identity filter.
+ if numtaps == 1:
+ return None
+
+ # Separable Kaiser low-pass filter.
+ if not radial:
+ f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
+ return torch.as_tensor(f, dtype=torch.float32)
+
+ # Radially symmetric jinc-based filter.
+ x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
+ r = np.hypot(*np.meshgrid(x, x))
+ f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
+ beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
+ w = np.kaiser(numtaps, beta)
+ f *= np.outer(w, w)
+ f /= np.sum(f)
+ return torch.as_tensor(f, dtype=torch.float32)
+
+ def extra_repr(self):
+ return '\n'.join([
+ f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},',
+ f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},',
+ f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},',
+ f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},',
+ f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},',
+ f'in_size={list(self.in_size)}, out_size={list(self.out_size)},',
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}'])
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisNetwork(torch.nn.Module):
+ def __init__(self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output image resolution.
+ img_channels, # Number of color channels.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB.
+ num_critical = 2, # Number of critically sampled layers at the end.
+ first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}).
+ first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}).
+ last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff.
+ margin_size = 10, # Number of additional pixels outside the image.
+ output_scale = 0.25, # Scale factor for the output image.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ super().__init__()
+ self.w_dim = w_dim
+ self.num_ws = num_layers + 2
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.num_layers = num_layers
+ self.num_critical = num_critical
+ self.margin_size = margin_size
+ self.output_scale = output_scale
+ self.num_fp16_res = num_fp16_res
+
+ # Geometric progression of layer cutoffs and min. stopbands.
+ last_cutoff = self.img_resolution / 2 # f_{c,N}
+ last_stopband = last_cutoff * last_stopband_rel # f_{t,N}
+ exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1)
+ cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i]
+ stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i]
+
+ # Compute remaining layer parameters.
+ sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i]
+ half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i]
+ sizes = sampling_rates + self.margin_size * 2
+ sizes[-2:] = self.img_resolution
+ channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max))
+ channels[-1] = self.img_channels
+
+ # Construct layers.
+ self.input = SynthesisInput(
+ w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]),
+ sampling_rate=sampling_rates[0], bandwidth=cutoffs[0])
+ self.layer_names = []
+ for idx in range(self.num_layers + 1):
+ prev = max(idx - 1, 0)
+ is_torgb = (idx == self.num_layers)
+ is_critically_sampled = (idx >= self.num_layers - self.num_critical)
+ use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution)
+ layer = SynthesisLayer(
+ w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16,
+ in_channels=int(channels[prev]), out_channels= int(channels[idx]),
+ in_size=int(sizes[prev]), out_size=int(sizes[idx]),
+ in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]),
+ in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx],
+ in_half_width=half_widths[prev], out_half_width=half_widths[idx],
+ **layer_kwargs)
+ name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
+ setattr(self, name, layer)
+ self.layer_names.append(name)
+
+ def forward(self, ws, **layer_kwargs):
+ misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
+ ws = ws.to(torch.float32).unbind(dim=1)
+
+ # Execute layers.
+ x = self.input(ws[0])
+ for name, w in zip(self.layer_names, ws[1:]):
+ x = getattr(self, name)(x, w, **layer_kwargs)
+ if self.output_scale != 1:
+ x = x * self.output_scale
+
+ # Ensure correct shape and dtype.
+ misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution])
+ x = x.to(torch.float32)
+ return x
+
+ def extra_repr(self):
+ return '\n'.join([
+ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
+ f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},',
+ f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}'])
+
+#----------------------------------------------------------------------------
+
+
+class Generator(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality.
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output resolution.
+ img_channels, # Number of output color channels.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ **synthesis_kwargs, # Arguments for SynthesisNetwork.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
+ self.num_ws = self.synthesis.num_ws
+ self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+ return img
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/models/superresolution.py b/modules/eg3ds/models/superresolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb1bf50ae0a3153600c297090a053b3d5f5111e1
--- /dev/null
+++ b/modules/eg3ds/models/superresolution.py
@@ -0,0 +1,360 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Superresolution network architectures from the paper
+"Efficient Geometry-aware 3D Generative Adversarial Networks"."""
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from modules.eg3ds.models.networks_stylegan2 import Conv2dLayer, SynthesisLayer, ToRGBLayer
+from modules.eg3ds.torch_utils.ops import upfirdn2d
+from modules.eg3ds.torch_utils import misc
+
+from modules.eg3ds.models.networks_stylegan2 import SynthesisBlock
+from modules.eg3ds.models.networks_stylegan3 import SynthesisLayer as AFSynthesisLayer
+from utils.commons.hparams import hparams
+
+
+#----------------------------------------------------------------------------
+
+# for 512x512 generation
+class SuperresolutionHybrid8X(torch.nn.Module):
+ def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
+ num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE
+ **block_kwargs):
+ super().__init__()
+ assert img_resolution == 512
+
+ use_fp16 = sr_num_fp16_res > 0
+ self.input_resolution = 128
+ self.sr_antialias = sr_antialias
+ self.block0 = SynthesisBlock(channels, 128, w_dim=512, resolution=256,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=512,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] != self.input_resolution:
+ x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+
+#----------------------------------------------------------------------------
+
+# for 256x256 generation
+
+class SuperresolutionHybrid4X(torch.nn.Module):
+ def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
+ num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE
+ **block_kwargs):
+ super().__init__()
+ assert img_resolution == 256
+ use_fp16 = sr_num_fp16_res > 0
+ self.sr_antialias = sr_antialias
+ self.input_resolution = 128
+ self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=128,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=256,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] < self.input_resolution:
+ x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+
+#----------------------------------------------------------------------------
+
+# for 128 x 128 generation
+
+class SuperresolutionHybrid2X(torch.nn.Module):
+ def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
+ num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE
+ **block_kwargs):
+ super().__init__()
+ assert img_resolution == 128
+
+ use_fp16 = sr_num_fp16_res > 0
+ self.input_resolution = 64
+ self.sr_antialias = sr_antialias
+ self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=64,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=128,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] != self.input_resolution:
+ x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+
+#----------------------------------------------------------------------------
+
+# TODO: Delete (here for backwards compatibility with old 256x256 models)
+
+class SuperresolutionHybridDeepfp32(torch.nn.Module):
+ def __init__(self, channels, img_resolution, sr_num_fp16_res,
+ num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE
+ **block_kwargs):
+ super().__init__()
+ assert img_resolution == 256
+ use_fp16 = sr_num_fp16_res > 0
+
+ self.input_resolution = 128
+ self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=128,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=256,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] < self.input_resolution:
+ x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False)
+ rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisBlockNoUp(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this block.
+ img_channels, # Number of output color channels.
+ is_last, # Is this the last block?
+ architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.is_last = is_last
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.fused_modconv_default = fused_modconv_default
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.num_conv = 0
+ self.num_torgb = 0
+
+ if in_channels == 0:
+ self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
+
+ if in_channels != 0:
+ self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution,
+ conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
+ conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ if is_last or architecture == 'skip':
+ self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
+ conv_clamp=conv_clamp, channels_last=self.channels_last)
+ self.num_torgb += 1
+
+ if in_channels != 0 and architecture == 'resnet':
+ self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
+ resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
+ _ = update_emas # unused
+ misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
+ w_iter = iter(ws.unbind(dim=1))
+ if ws.device.type != 'cuda':
+ force_fp32 = True
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+ if fused_modconv is None:
+ fused_modconv = self.fused_modconv_default
+ if fused_modconv == 'inference_only':
+ fused_modconv = (not self.training)
+
+ # Input.
+ if self.in_channels == 0:
+ x = self.const.to(dtype=dtype, memory_format=memory_format)
+ x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
+ else:
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # Main layers.
+ if self.in_channels == 0:
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ elif self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
+ x = y.add_(x)
+ else:
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+
+ # ToRGB.
+ # if img is not None:
+ # misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
+ # img = upfirdn2d.upsample2d(img, self.resample_filter)
+ if self.is_last or self.architecture == 'skip':
+ y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
+ y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+ img = img.add_(y) if img is not None else y
+
+ assert x.dtype == dtype
+ assert img is None or img.dtype == torch.float32
+ return x, img
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+
+#----------------------------------------------------------------------------
+# for 512x512 generation
+class ResBlock2d(nn.Module):
+ """
+ Res block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, kernel_size, padding):
+ super(ResBlock2d, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.act = nn.ReLU(inplace=False)
+ # self.act = nn.LeakyReLU(inplace=False) # run3
+ # self.norm1 = nn.BatchNorm2d(in_features, affine=True)
+ # self.norm2 = nn.BatchNorm2d(in_features, affine=True)
+
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.act(out)
+ out = self.conv2(out)
+ out = self.act(out)
+ out = out + x
+ return out
+
+ # def forward(self, x):
+ # out = self.norm1(x)
+ # out = F.relu(out)
+ # out = self.conv1(out)
+ # out = self.norm2(out)
+ # out = F.relu(out)
+ # out = self.conv2(out)
+ # out = x + out
+ # return out
+
+
+class LargeSynthesisBlock0(nn.Module):
+ def __init__(self, channels, use_fp16, **block_kwargs):
+ super().__init__()
+ self.block = SynthesisBlock(channels, 256, w_dim=512, resolution=256,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.resblocks = nn.Sequential(*[
+ ResBlock2d(256, kernel_size=3, padding=1) for _ in range(hparams['resblocks_in_large_sr'])
+ ])
+ self.to_rgb = nn.Conv2d(256, 3, kernel_size=1)
+
+ def forward(self, x, rgb, ws, **block_kwargs):
+ x, rgb = self.block(x, rgb, ws, **block_kwargs)
+ x = self.resblocks(x)
+ rgb = rgb + self.to_rgb(x)
+ return x, rgb
+
+class LargeSynthesisBlock1(nn.Module):
+ def __init__(self, use_fp16, **block_kwargs):
+ super().__init__()
+ self.block = SynthesisBlock(256, 128, w_dim=512, resolution=512,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.resblocks = nn.Sequential(*[
+ ResBlock2d(128, kernel_size=3, padding=1) for _ in range(hparams['resblocks_in_large_sr'])
+ ])
+ self.to_rgb = nn.Conv2d(128, 3, kernel_size=1)
+
+ def forward(self, x, rgb, ws, **block_kwargs):
+ x, rgb = self.block(x, rgb, ws, **block_kwargs)
+ x = self.resblocks(x)
+ rgb = rgb + self.to_rgb(x)
+ return x, rgb
+
+class SuperresolutionHybrid8XDC(torch.nn.Module):
+ def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias, large_sr=False, **block_kwargs):
+ super().__init__()
+ assert img_resolution == 512
+
+ use_fp16 = sr_num_fp16_res > 0
+ self.input_resolution = 128
+ self.sr_antialias = sr_antialias
+ if large_sr is True:
+ self.block0 = LargeSynthesisBlock0(channels, use_fp16=sr_num_fp16_res > 0, **block_kwargs)
+ self.block1 = LargeSynthesisBlock1(use_fp16=sr_num_fp16_res > 0, **block_kwargs)
+ else:
+ self.block0 = SynthesisBlock(channels, 256, w_dim=512, resolution=256,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.block1 = SynthesisBlock(256, 128, w_dim=512, resolution=512,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] != self.input_resolution:
+ x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+#----------------------------------------------------------------------------
\ No newline at end of file
diff --git a/modules/eg3ds/torch_utils/__init__.py b/modules/eg3ds/torch_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/modules/eg3ds/torch_utils/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/modules/eg3ds/torch_utils/custom_ops.py b/modules/eg3ds/torch_utils/custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed2524f47ab3d5b8750cfb868cc14012f424acc8
--- /dev/null
+++ b/modules/eg3ds/torch_utils/custom_ops.py
@@ -0,0 +1,159 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import glob
+import hashlib
+import importlib
+import os
+import re
+import shutil
+import uuid
+
+import torch
+import torch.utils.cpp_extension
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+ patterns = [
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+ ]
+ for pattern in patterns:
+ matches = sorted(glob.glob(pattern))
+ if len(matches):
+ return matches[-1]
+ return None
+
+#----------------------------------------------------------------------------
+
+def _get_mangled_gpu_name():
+ name = torch.cuda.get_device_name().lower()
+ out = []
+ for c in name:
+ if re.match('[a-z0-9_-]+', c):
+ out.append(c)
+ else:
+ out.append('-')
+ return ''.join(out)
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
+ assert verbosity in ['none', 'brief', 'full']
+ if headers is None:
+ headers = []
+ if source_dir is not None:
+ sources = [os.path.join(source_dir, fname) for fname in sources]
+ headers = [os.path.join(source_dir, fname) for fname in headers]
+
+ # Already cached?
+ if module_name in _cached_plugins:
+ return _cached_plugins[module_name]
+
+ # Print status.
+ if verbosity == 'full':
+ print(f'Setting up PyTorch plugin "{module_name}"...')
+ elif verbosity == 'brief':
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+ verbose_build = (verbosity == 'full')
+
+ # Compile and load.
+ try: # pylint: disable=too-many-nested-blocks
+ # Make sure we can find the necessary compiler binaries.
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+ os.environ['PATH'] += ';' + compiler_bindir
+
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
+ # break the build or unnecessarily restrict what's available to nvcc.
+ # Unset it to let nvcc decide based on what's available on the
+ # machine.
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
+
+ # Incremental build md5sum trickery. Copies all the input source files
+ # into a cached build directory under a combined md5 digest of the input
+ # source files. Copying is done only if the combined digest has changed.
+ # This keeps input file timestamps and filenames the same as in previous
+ # extension builds, allowing for fast incremental rebuilds.
+ #
+ # This optimization is done only in case all the source files reside in
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+ # environment variable is set (we take this as a signal that the user
+ # actually cares about this.)
+ #
+ # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
+ # around the *.cu dependency bug in ninja config.
+ #
+ all_source_files = sorted(sources + headers)
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
+
+ # Compute combined hash digest for all source files.
+ hash_md5 = hashlib.md5()
+ for src in all_source_files:
+ with open(src, 'rb') as f:
+ hash_md5.update(f.read())
+
+ # Select cached build directory name.
+ source_digest = hash_md5.hexdigest()
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
+
+ if not os.path.isdir(cached_build_dir):
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
+ os.makedirs(tmpdir)
+ for src in all_source_files:
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
+ try:
+ os.replace(tmpdir, cached_build_dir) # atomic
+ except OSError:
+ # source directory already exists, delete tmpdir and its contents.
+ shutil.rmtree(tmpdir)
+ if not os.path.isdir(cached_build_dir): raise
+
+ # Compile.
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
+ else:
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+
+ # Load.
+ module = importlib.import_module(module_name)
+
+ except:
+ if verbosity == 'brief':
+ print('Failed!')
+ raise
+
+ # Print status and add to cache dict.
+ if verbosity == 'full':
+ print(f'Done setting up PyTorch plugin "{module_name}".')
+ elif verbosity == 'brief':
+ print('Done.')
+ _cached_plugins[module_name] = module
+ return module
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/misc.py b/modules/eg3ds/torch_utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8d56d3d55fda85709ed63716485c7d55514bd1c
--- /dev/null
+++ b/modules/eg3ds/torch_utils/misc.py
@@ -0,0 +1,268 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+from modules.eg3ds import dnnlib
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device('cpu')
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+ nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+ assert isinstance(input, torch.Tensor)
+ if posinf is None:
+ posinf = torch.finfo(input.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(input.dtype).min
+ assert nan == 0
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+ symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to temporarily suppress known warnings in torch.jit.trace().
+# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
+
+@contextlib.contextmanager
+def suppress_tracer_warnings():
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
+ warnings.filters.insert(0, flt)
+ yield
+ warnings.filters.remove(flt)
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+ if tensor.ndim != len(ref_shape):
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+ if ref_size is None:
+ pass
+ elif isinstance(ref_size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+ elif isinstance(size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+ elif size != ref_size:
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+ def decorator(*args, **kwargs):
+ with torch.autograd.profiler.record_function(fn.__name__):
+ return fn(*args, **kwargs)
+ decorator.__name__ = fn.__name__
+ return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
+
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+def params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.parameters()) + list(module.buffers())
+
+def named_params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.named_parameters()) + list(module.named_buffers())
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False):
+ assert isinstance(src_module, torch.nn.Module)
+ assert isinstance(dst_module, torch.nn.Module)
+ src_tensors = dict(named_params_and_buffers(src_module))
+ for name, tensor in named_params_and_buffers(dst_module):
+ assert (name in src_tensors) or (not require_all)
+ if name in src_tensors:
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ assert isinstance(module, torch.nn.Module)
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+ assert isinstance(module, torch.nn.Module)
+ for name, tensor in named_params_and_buffers(module):
+ fullname = type(module).__name__ + '.' + name
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+ continue
+ tensor = tensor.detach()
+ if tensor.is_floating_point():
+ tensor = nan_to_num(tensor)
+ other = tensor.clone()
+ torch.distributed.broadcast(tensor=other, src=0)
+ assert (tensor == other).all(), fullname
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+ assert isinstance(module, torch.nn.Module)
+ assert not isinstance(module, torch.jit.ScriptModule)
+ assert isinstance(inputs, (tuple, list))
+
+ # Register hooks.
+ entries = []
+ nesting = [0]
+ def pre_hook(_mod, _inputs):
+ nesting[0] += 1
+ def post_hook(mod, _inputs, outputs):
+ nesting[0] -= 1
+ if nesting[0] <= max_nesting:
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+ # Run module.
+ outputs = module(*inputs)
+ for hook in hooks:
+ hook.remove()
+
+ # Identify unique outputs, parameters, and buffers.
+ tensors_seen = set()
+ for e in entries:
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+ # Filter out redundant entries.
+ if skip_redundant:
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+ # Construct table.
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+ rows += [['---'] * len(rows[0])]
+ param_total = 0
+ buffer_total = 0
+ submodule_names = {mod: name for name, mod in module.named_modules()}
+ for e in entries:
+ name = '' if e.mod is module else submodule_names[e.mod]
+ param_size = sum(t.numel() for t in e.unique_params)
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+ rows += [[
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]]
+ for idx in range(1, len(e.outputs)):
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+ param_total += param_size
+ buffer_total += buffer_size
+ rows += [['---'] * len(rows[0])]
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+ # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+ print()
+ return outputs
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/__init__.py b/modules/eg3ds/torch_utils/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/modules/eg3ds/torch_utils/ops/bias_act.cpp b/modules/eg3ds/torch_utils/ops/bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ee6f6d0caaf4f84b94851d223e384344e1109cdc
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/bias_act.cpp
@@ -0,0 +1,103 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ kernel = choose_bias_act_kernel(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/bias_act.cu b/modules/eg3ds/torch_utils/ops/bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..71ca3900deda41e62d80044f0e409875f4c794b5
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/bias_act.cu
@@ -0,0 +1,177 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template struct InternalType;
+template <> struct InternalType { typedef double scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+ if (p.act == 1) return (void*)bias_act_kernel;
+ if (p.act == 2) return (void*)bias_act_kernel;
+ if (p.act == 3) return (void*)bias_act_kernel;
+ if (p.act == 4) return (void*)bias_act_kernel;
+ if (p.act == 5) return (void*)bias_act_kernel;
+ if (p.act == 6) return (void*)bias_act_kernel;
+ if (p.act == 7) return (void*)bias_act_kernel;
+ if (p.act == 8) return (void*)bias_act_kernel;
+ if (p.act == 9) return (void*)bias_act_kernel;
+ return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/bias_act.h b/modules/eg3ds/torch_utils/ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..8994bfb4e9cae790865348e08de5f685152d3344
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/bias_act.h
@@ -0,0 +1,42 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/bias_act.py b/modules/eg3ds/torch_utils/ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..3984639c54faae2233837175ccb210a63016426c
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/bias_act.py
@@ -0,0 +1,211 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import numpy as np
+import torch
+from modules.eg3ds import dnnlib
+
+from .. import custom_ops
+from .. import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+_null_tensor = torch.empty([0])
+
+def _init():
+ global _plugin
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='bias_act_plugin',
+ sources=['bias_act.cpp', 'bias_act.cu'],
+ headers=['bias_act.h'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math'],
+ )
+ return True
+
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can be of any shape.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `dim`.
+ dim: The dimension in `x` corresponding to the elements of `b`.
+ The value of `dim` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying 1.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ dy if spec.has_2nd_grad else _null_tensor,
+ x, b, y)
+ return dx
+
+ @staticmethod
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/conv2d_gradfix.py b/modules/eg3ds/torch_utils/ops/conv2d_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a177cc1c0b6eabf16908cf9afaa4387e7716b72
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/conv2d_gradfix.py
@@ -0,0 +1,199 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.conv2d` that supports
+arbitrarily high order gradients with zero performance penalty."""
+
+import contextlib
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+
+@contextlib.contextmanager
+def no_weight_gradients(disable=True):
+ global weight_gradients_disabled
+ old = weight_gradients_disabled
+ if disable:
+ weight_gradients_disabled = True
+ yield
+ weight_gradients_disabled = old
+
+#----------------------------------------------------------------------------
+
+def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op(input):
+ assert isinstance(input, torch.Tensor)
+ if (not enabled) or (not torch.backends.cudnn.enabled):
+ return False
+ if input.device.type != 'cuda':
+ return False
+ return True
+
+def _tuple_of_ints(xs, ndim):
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+ assert len(xs) == ndim
+ assert all(isinstance(x, int) for x in xs)
+ return xs
+
+#----------------------------------------------------------------------------
+
+_conv2d_gradfix_cache = dict()
+_null_tensor = torch.empty([0])
+
+def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+ # Parse arguments.
+ ndim = 2
+ weight_shape = tuple(weight_shape)
+ stride = _tuple_of_ints(stride, ndim)
+ padding = _tuple_of_ints(padding, ndim)
+ output_padding = _tuple_of_ints(output_padding, ndim)
+ dilation = _tuple_of_ints(dilation, ndim)
+
+ # Lookup from cache.
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+ if key in _conv2d_gradfix_cache:
+ return _conv2d_gradfix_cache[key]
+
+ # Validate arguments.
+ assert groups >= 1
+ assert len(weight_shape) == ndim + 2
+ assert all(stride[i] >= 1 for i in range(ndim))
+ assert all(padding[i] >= 0 for i in range(ndim))
+ assert all(dilation[i] >= 0 for i in range(ndim))
+ if not transpose:
+ assert all(output_padding[i] == 0 for i in range(ndim))
+ else: # transpose
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+ # Helpers.
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+ def calc_output_padding(input_shape, output_shape):
+ if transpose:
+ return [0, 0]
+ return [
+ input_shape[i + 2]
+ - (output_shape[i + 2] - 1) * stride[i]
+ - (1 - 2 * padding[i])
+ - dilation[i] * (weight_shape[i + 2] - 1)
+ for i in range(ndim)
+ ]
+
+ # Forward & backward.
+ class Conv2d(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, weight, bias):
+ assert weight.shape == weight_shape
+ ctx.save_for_backward(
+ input if weight.requires_grad else _null_tensor,
+ weight if input.requires_grad else _null_tensor,
+ )
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
+ a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
+ c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
+ c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
+ c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ if transpose:
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ input_shape = ctx.input_shape
+ grad_input = None
+ grad_weight = None
+ grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+ grad_input = op.apply(grad_output, weight, None)
+ assert grad_input.shape == input_shape
+
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+ grad_weight = Conv2dGradWeight.apply(grad_output, input, weight)
+ assert grad_weight.shape == weight_shape
+
+ if ctx.needs_input_grad[2]:
+ grad_bias = grad_output.sum([0, 2, 3])
+
+ return grad_input, grad_weight, grad_bias
+
+ # Gradient with respect to the weights.
+ class Conv2dGradWeight(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, weight):
+ ctx.save_for_backward(
+ grad_output if input.requires_grad else _null_tensor,
+ input if grad_output.requires_grad else _null_tensor,
+ )
+ ctx.grad_output_shape = grad_output.shape
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
+ a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+ c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]
+
+
+ @staticmethod
+ def backward(ctx, grad2_grad_weight):
+ grad_output, input = ctx.saved_tensors
+ grad_output_shape = ctx.grad_output_shape
+ input_shape = ctx.input_shape
+ grad2_grad_output = None
+ grad2_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
+ assert grad2_grad_output.shape == grad_output_shape
+
+ if ctx.needs_input_grad[1]:
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
+ assert grad2_input.shape == input_shape
+
+ return grad2_grad_output, grad2_input
+
+ _conv2d_gradfix_cache[key] = Conv2d
+ return Conv2d
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/conv2d_resample.py b/modules/eg3ds/torch_utils/ops/conv2d_resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..5daad2efadcd79513aaf8aee9ecb08a5ce04797e
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/conv2d_resample.py
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""2D convolution with optional up/downsampling."""
+
+import torch
+
+from .. import misc
+from . import conv2d_gradfix
+from . import upfirdn2d
+from .upfirdn2d import _parse_padding
+from .upfirdn2d import _get_filter_size
+
+#----------------------------------------------------------------------------
+
+def _get_weight_shape(w):
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ shape = [int(sz) for sz in w.shape]
+ misc.assert_shape(w, shape)
+ return shape
+
+#----------------------------------------------------------------------------
+
+def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
+ """
+ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
+
+ # Flip weight if requested.
+ # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+ if not flip_weight and (kw > 1 or kh > 1):
+ w = w.flip([2, 3])
+
+ # Execute using conv2d_gradfix.
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
+ return op(x, w, stride=stride, padding=padding, groups=groups)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+ r"""2D convolution with optional up/downsampling.
+
+ Padding is performed only once at the beginning, not between the operations.
+
+ Args:
+ x: Input tensor of shape
+ `[batch_size, in_channels, in_height, in_width]`.
+ w: Weight tensor of shape
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
+ calling upfirdn2d.setup_filter(). None = identity (default).
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ groups: Split input channels into N groups (default: 1).
+ flip_weight: False = convolution, True = correlation (default: True).
+ flip_filter: False = convolution, True = correlation (default: False).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
+ if isinstance(f, torch.Tensor) and f.dtype == torch.float16:
+ f = f.float()
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+ assert isinstance(up, int) and (up >= 1)
+ assert isinstance(down, int) and (down >= 1)
+ assert isinstance(groups, int) and (groups >= 1)
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+ fw, fh = _get_filter_size(f)
+ px0, px1, py0, py1 = _parse_padding(padding)
+
+ # Adjust padding to account for up/downsampling.
+ if up > 1:
+ px0 += (fw + up - 1) // 2
+ px1 += (fw - up) // 2
+ py0 += (fh + up - 1) // 2
+ py1 += (fh - up) // 2
+ if down > 1:
+ px0 += (fw - down + 1) // 2
+ px1 += (fw - down) // 2
+ py0 += (fh - down + 1) // 2
+ py1 += (fh - down) // 2
+
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ return x
+
+ # Fast path: downsampling only => use strided convolution.
+ if down > 1 and up == 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
+ if up > 1:
+ if groups == 1:
+ w = w.transpose(0, 1)
+ else:
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
+ w = w.transpose(1, 2)
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+ px0 -= kw - 1
+ px1 -= kw - up
+ py0 -= kh - 1
+ py1 -= kh - up
+ pxt = max(min(-px0, -px1), 0)
+ pyt = max(min(-py0, -py1), 0)
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
+ if up == 1 and down == 1:
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
+
+ # Fallback: Generic reference implementation.
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu.cpp b/modules/eg3ds/torch_utils/ops/filtered_lrelu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4f55466235a020b0f5e150350bfdcd8b2a1e579d
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu.cpp
@@ -0,0 +1,304 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include
+#include
+#include "filtered_lrelu.h"
+
+//------------------------------------------------------------------------
+
+static std::tuple filtered_lrelu(
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
+
+ // Figure out how much shared memory is available on the device.
+ int maxSharedBytes = 0;
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
+ int sharedKB = maxSharedBytes >> 10;
+
+ // Populate enough launch parameters to check if a CUDA kernel exists.
+ filtered_lrelu_kernel_params p;
+ p.up = up;
+ p.down = down;
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ if (!test_spec.exec)
+ {
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
+ }
+
+ // Input/output element size.
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
+
+ // Input sizes.
+ int64_t xw = (int)x.size(3);
+ int64_t xh = (int)x.size(2);
+ int64_t fut_w = (int)fu.size(-1) - 1;
+ int64_t fut_h = (int)fu.size(0) - 1;
+ int64_t fdt_w = (int)fd.size(-1) - 1;
+ int64_t fdt_h = (int)fd.size(0) - 1;
+
+ // Logical size of upsampled buffer.
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
+
+ // Compute output size and allocate.
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
+
+ // Allocate sign tensor.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ int64_t sw_active = 0; // Active width of sign tensor.
+ if (writeSigns)
+ {
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+ else if (readSigns)
+ sw_active = s.size(3) << 2;
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
+ }
+
+ // Populate rest of CUDA kernel parameters.
+ p.x = x.data_ptr();
+ p.y = y.data_ptr();
+ p.b = b.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
+ p.fu = fu.data_ptr();
+ p.fd = fd.data_ptr();
+ p.pad0 = make_int2(px0, py0);
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.flip = (flip_filters) ? 1 : 0;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
+
+ // x, y, b strides are in bytes.
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
+ p.bStride = sz * b.stride(0);
+
+ // fu, fd strides are in elements.
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
+
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
+ bool index64b = false;
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
+ if (s.numel() > INT_MAX) index64b = true;
+
+ // Choose CUDA kernel.
+ filtered_lrelu_kernel_spec spec = { 0 };
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
+ {
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
+ {
+ // Choose kernel based on index type, datatype and sign read/write modes.
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ }
+ });
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = spec.numWarps * 32;
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
+ int gz = p.yShape.z * p.yShape.w;
+
+ // Repeat multiple horizontal tiles in a CTA?
+ if (spec.xrep)
+ {
+ p.tilesXrep = spec.xrep;
+ p.tilesXdim = gx;
+
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
+ std::swap(gx, gy);
+ }
+ else
+ {
+ p.tilesXrep = 0;
+ p.tilesXdim = 0;
+ }
+
+ // Launch filter setup kernel.
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
+
+ // Copy kernels to constant memory.
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+
+ // Set cache and shared memory configurations for main kernel.
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
+
+ // Launch main kernel.
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
+ {
+ p.blockZofs = zofs;
+ int subGz = std::min(maxSubGz, gz - zofs);
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
+ }
+
+ // Done.
+ return std::make_tuple(y, so, 0);
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
+
+ // Output signs if we don't have sign input.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ if (writeSigns)
+ {
+ int64_t sw = x.size(3);
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
+ }
+
+ // Initialize CUDA kernel parameters.
+ filtered_lrelu_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+
+ // Choose CUDA kernel.
+ void* func = 0;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
+ {
+ if (writeSigns)
+ func = choose_filtered_lrelu_act_kernel();
+ else if (readSigns)
+ func = choose_filtered_lrelu_act_kernel();
+ else
+ func = choose_filtered_lrelu_act_kernel();
+ });
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = 128; // 4 warps per block.
+
+ // Logical size of launch = writeSigns ? p.s : p.x
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
+ gx = (gx - 1) / bx + 1;
+
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
+ const uint32_t gmax = 65535;
+ gy = std::min(gy, gmax);
+ gz = std::min(gz, gmax);
+
+ // Launch.
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
+ return so;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
+}
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu.cu b/modules/eg3ds/torch_utils/ops/filtered_lrelu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..aaac95408365f023ffaa4cb89348d499d3b948f0
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu.cu
@@ -0,0 +1,1288 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include "filtered_lrelu.h"
+#include
+
+//------------------------------------------------------------------------
+// Helpers.
+
+enum // Filter modes.
+{
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
+};
+
+template struct InternalType;
+template <> struct InternalType
+{
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
+};
+template <> struct InternalType
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+template <> struct InternalType
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+
+#define MIN(A, B) ((A) < (B) ? (A) : (B))
+#define MAX(A, B) ((A) > (B) ? (A) : (B))
+#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
+
+// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
+template __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
+{
+ if ((N & (N-1)) && N <= 256)
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
+ else
+ y = i/N;
+
+ x = i - y*N;
+}
+
+// Type cast stride before reading it.
+template __device__ __forceinline__ T get_stride(const int64_t& x)
+{
+ return *reinterpret_cast(&x);
+}
+
+//------------------------------------------------------------------------
+// Filters, setup kernel, copying function.
+
+#define MAX_FILTER_SIZE 32
+
+// Combined up/down filter buffers so that transfer can be done with one copy.
+__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
+__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
+
+// Accessors to combined buffers to index up/down filters individually.
+#define c_fu (c_fbuf)
+#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+#define g_fu (g_fbuf)
+#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+
+// Set up filters into global memory buffer.
+static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
+{
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
+ {
+ int x, y;
+ fast_div_mod(x, y, idx);
+
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
+ if (p.fuShape.y > 0)
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
+ else
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
+
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
+ if (p.fdShape.y > 0)
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
+ else
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
+ }
+}
+
+// Host function to copy filters written by setup kernel into constant buffer for main kernel.
+template static cudaError_t copy_filters(cudaStream_t stream)
+{
+ void* src = 0;
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
+ if (err) return err;
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
+}
+
+//------------------------------------------------------------------------
+// Coordinate spaces:
+// - Relative to input tensor: inX, inY, tileInX, tileInY
+// - Relative to input tile: relInX, relInY, tileInW, tileInH
+// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
+// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
+// - Relative to output tensor: outX, outY, tileOutX, tileOutY
+//
+// Relationships between coordinate spaces:
+// - inX = tileInX + relInX
+// - inY = tileInY + relInY
+// - relUpX = relInX * up + phaseInX
+// - relUpY = relInY * up + phaseInY
+// - relUpX = relOutX * down
+// - relUpY = relOutY * down
+// - outX = tileOutX + relOutX
+// - outY = tileOutY + relOutY
+
+extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
+
+template
+static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
+{
+ // Check that we don't try to support non-existing filter modes.
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
+
+ // Static definitions.
+ typedef typename InternalType::scalar_t scalar_t;
+ typedef typename InternalType::vec2_t vec2_t;
+ typedef typename InternalType::vec4_t vec4_t;
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
+
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
+
+ // Sizes of logical buffers.
+ const int szIn = tileInH_up * tileInW;
+ const int szUpX = tileInH_up * tileUpW;
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
+ const int szDownX = tileUpH * tileOutW;
+
+ // Sizes for shared memory arrays.
+ const int s_buf0_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUFD) ? szIn :
+ -1;
+ const int s_buf1_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
+ (filterMode == MODE_FUSD) ? szUpXY :
+ (filterMode == MODE_SUFD) ? szUpX :
+ (filterMode == MODE_FUFD) ? szUpXY :
+ -1;
+
+ // Ensure U128 alignment.
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
+
+ // Check at compile time that we don't use too much shared memory.
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
+
+ // Declare shared memory arrays.
+ scalar_t* s_buf0;
+ scalar_t* s_buf1;
+ if (sharedKB <= 48)
+ {
+ // Allocate shared memory arrays here.
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
+ s_buf0 = s_buf0_st;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+ else
+ {
+ // Use the dynamically allocated shared memory array.
+ s_buf0 = (scalar_t*)s_buf_raw;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+
+ // Pointers to the buffers.
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
+ if (filterMode == MODE_SUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ s_tileDownX = s_buf1;
+ }
+ else if (filterMode == MODE_FUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ s_tileDownX = s_buf0;
+ }
+ else if (filterMode == MODE_SUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ }
+ else if (filterMode == MODE_FUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ }
+
+ // Allow large grids in z direction via per-launch offset.
+ int channelIdx = blockIdx.z + p.blockZofs;
+ int batchIdx = channelIdx / p.yShape.z;
+ channelIdx -= batchIdx * p.yShape.z;
+
+ // Offset to output feature map. In bytes.
+ index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w);
+
+ // Sign shift amount.
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
+
+ // Inner tile loop.
+ #pragma unroll 1
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
+ {
+ // Locate output tile.
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
+ int tileOutX = tileX * tileOutW;
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
+
+ // Locate input tile.
+ int tmpX = tileOutX * down - p.pad0.x;
+ int tmpY = tileOutY * down - p.pad0.y;
+ int tileInX = CEIL_DIV(tmpX, up);
+ int tileInY = CEIL_DIV(tmpY, up);
+ const int phaseInX = tileInX * up - tmpX;
+ const int phaseInY = tileInY * up - tmpY;
+
+ // Extra sync if input and output buffers are the same and we are not on first tile.
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
+ __syncthreads();
+
+ // Load input tile & apply bias. Unrolled.
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride)));
+ index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w);
+ int idx = threadIdx.x;
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
+ #pragma unroll
+ for (int loop = 0; loop < loopCountIN; loop++)
+ {
+ int relInX, relInY;
+ fast_div_mod(relInX, relInY, idx);
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b;
+
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
+ if (!skip)
+ s_tileIn[idx] = v;
+
+ idx += threadsPerBlock;
+ }
+
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
+ {
+ // Horizontal upsampling.
+ __syncthreads();
+ if (up == 4)
+ {
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec4_t v = InternalType::zero_vec4();
+ scalar_t a = s_tileIn[src0];
+ if (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInX == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInX == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ s_tileUpX[dst+2] = v.z;
+ s_tileUpX[dst+3] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ bool p0 = (phaseInX == 0);
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ scalar_t a = s_tileIn[src0];
+ if (p0) // (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ }
+ }
+
+ // Vertical upsampling & nonlinearity.
+
+ __syncthreads();
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
+ if (up == 4)
+ {
+ minY -= 3; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec4_t v = InternalType::zero_vec4();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInY == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInY == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+ index_t si2 = si0 + p.sShape.x * 2;
+ index_t si3 = si0 + p.sShape.x * 3;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read signs and apply.
+ {
+ if ((uint32_t)signXb < p.swLimit)
+ {
+ int ss = (signX & 3) << 1;
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ minY -= 1; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec2_t v = InternalType::zero_vec2();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ // Combine signs.
+ int s = sx + sy;
+ s <<= signXo;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+
+ // Combine signs.
+ int s = sx + sy;
+ s <<= signXo;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read signs and apply.
+ {
+ if ((uint32_t)signXb < p.swLimit)
+ {
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ }
+
+ if (!downInline)
+ {
+ // Write into temporary buffer.
+ s_tileUpXY[dst] = v.x;
+ if (relUpY0 < tileUpH - 1)
+ s_tileUpXY[dst + tileUpW] = v.y;
+ }
+ else
+ {
+ // Write directly into output buffer.
+ if ((uint32_t)x < p.yShape.x)
+ {
+ int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
+ index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut;
+ if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
+ if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
+ }
+ }
+ }
+ }
+ }
+ else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
+ {
+ // Full upsampling filter.
+
+ if (up == 2)
+ {
+ // 2 x 2-wide.
+ __syncthreads();
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
+ for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
+ {
+ int relUpX0, relUpY0;
+ fast_div_mod(relUpX0, relUpY0, idx);
+ int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
+ int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
+ int src0 = relInX0 + tileInW * relInY0;
+ int tap0y = (relInY0 * up + phaseInY - relUpY0);
+
+ #define X_LOOP(TAPY, PX) \
+ for (int sx = 0; sx < fuSize / up; sx++) \
+ { \
+ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
+ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
+ }
+
+ vec4_t v = InternalType::zero_vec4();
+ if (tap0y == 0 && phaseInX == 0)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(0, 0) }
+ if (tap0y == 0 && phaseInX == 1)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(0, 1) }
+ if (tap0y == 1 && phaseInX == 0)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(1, 0) }
+ if (tap0y == 1 && phaseInX == 1)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(1, 1) }
+
+ #undef X_LOOP
+
+ int x = tileOutX * down + relUpX0;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31;
+ int sy = __float_as_uint(v.y) >> 31;
+ int sz = __float_as_uint(v.z) >> 31;
+ int sw = __float_as_uint(v.w) >> 31;
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31;
+ int sy = __float_as_uint(v.y) >> 31;
+ int sz = __float_as_uint(v.z) >> 31;
+ int sw = __float_as_uint(v.w) >> 31;
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read sign and apply.
+ {
+ if ((uint32_t)signY < p.sShape.y)
+ {
+ int s = 0;
+ if ((uint32_t)signXb < p.swLimit) s = p.s[si];
+ if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
+ s >>= (signX & 3) << 1;
+ if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
+ if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
+ if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
+ if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+
+ s_tileUpXY[idx + 0] = v.x;
+ s_tileUpXY[idx + 1] = v.y;
+ s_tileUpXY[idx + 2] = v.z;
+ s_tileUpXY[idx + 3] = v.w;
+ }
+ }
+ else if (up == 1)
+ {
+ __syncthreads();
+ uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
+ {
+ int relUpX0, relUpY0;
+ fast_div_mod(relUpX0, relUpY0, idx);
+ scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
+
+ int x = tileOutX * down + relUpX0;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ v *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write sign.
+ uint32_t s = 0;
+ uint32_t signXbit = (1u << signXo);
+ if (v < 0.f)
+ {
+ s = signXbit;
+ v *= p.slope;
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ s = signXbit * 2;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+ p.s[si] = s; // Write.
+ }
+ }
+ else
+ {
+ // Determine and write sign.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ uint32_t s = 0;
+ uint32_t signXbit = (1u << signXo);
+ if (v < 0.f)
+ {
+ s = signXbit;
+ v *= p.slope;
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ s = signXbit * 2;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+ p.s[si] = s; // Write.
+ }
+ else
+ {
+ // Just compute the value.
+ if (v < 0.f) v *= p.slope;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ }
+ }
+ else if (signRead)
+ {
+ // Read sign and apply if within sign tensor bounds.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
+ {
+ int s = p.s[si];
+ s >>= signXo;
+ if (s & 1) v *= p.slope;
+ if (s & 2) v = 0.f;
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v < 0.f) v *= p.slope;
+ v = InternalType::clamp(v, p.clamp);
+ }
+
+ if (!downInline) // Write into temporary buffer.
+ s_tileUpXY[idx] = v;
+ else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
+ *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
+ }
+ }
+ }
+
+ // Downsampling.
+ if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
+ {
+ // Horizontal downsampling.
+ __syncthreads();
+ if (down == 4 && tileOutW % 4 == 0)
+ {
+ // Calculate 4 pixels at a time.
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src0 = relUpY * tileUpW + relUpX0;
+ vec4_t v = InternalType::zero_vec4();
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+ v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
+ v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
+ v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
+ }
+ s_tileDownX[idx+0] = v.x;
+ s_tileDownX[idx+1] = v.y;
+ s_tileDownX[idx+2] = v.z;
+ s_tileDownX[idx+3] = v.w;
+ }
+ }
+ else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
+ {
+ // Calculate 2 pixels at a time.
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src0 = relUpY * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+ v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
+ }
+ s_tileDownX[idx+0] = v.x;
+ s_tileDownX[idx+1] = v.y;
+ }
+ }
+ else
+ {
+ // Calculate 1 pixel at a time.
+ for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src = relUpY * tileUpW + relUpX0;
+ scalar_t v = 0.f;
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
+ s_tileDownX[idx] = v;
+ }
+ }
+
+ // Vertical downsampling & store output tile.
+ __syncthreads();
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
+ {
+ int relOutX, relOutY0;
+ fast_div_mod(relOutX, relOutY0, idx);
+ int relUpY0 = relOutY0 * down;
+ int src0 = relUpY0 * tileOutW + relOutX;
+ scalar_t v = 0;
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
+
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY0;
+
+ if (outX < p.yShape.x & outY < p.yShape.y)
+ *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v;
+ }
+ }
+ else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
+ {
+ // Full downsampling filter.
+ if (down == 2)
+ {
+ // 2-wide.
+ __syncthreads();
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
+ {
+ int relOutX0, relOutY0;
+ fast_div_mod(relOutX0, relOutY0, idx);
+ int relUpX0 = relOutX0 * down;
+ int relUpY0 = relOutY0 * down;
+ int src0 = relUpY0 * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ #pragma unroll
+ for (int sy = 0; sy < fdSize; sy++)
+ #pragma unroll
+ for (int sx = 0; sx < fdSize; sx++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+ v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+ }
+
+ int outX = tileOutX + relOutX0;
+ int outY = tileOutY + relOutY0;
+ if ((uint32_t)outY < p.yShape.y)
+ {
+ index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut;
+ if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
+ if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y;
+ }
+ }
+ }
+ else if (down == 1 && !downInline)
+ {
+ // Thread per pixel.
+ __syncthreads();
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
+ {
+ int relOutX0, relOutY0;
+ fast_div_mod(relOutX0, relOutY0, idx);
+ scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
+
+ int outX = tileOutX + relOutX0;
+ int outY = tileOutY + relOutY0;
+ if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
+ *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v;
+ }
+ }
+ }
+
+ if (!enableXrep)
+ break;
+ }
+}
+
+//------------------------------------------------------------------------
+// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
+// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
+
+template
+static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+
+ // Indexing.
+ int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
+ int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
+ int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
+
+ // Loop to accommodate oversized tensors.
+ for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
+ for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
+ {
+ // Extract z and w (channel, minibatch index).
+ int32_t w = q / p.xShape.z;
+ int32_t z = q - w * p.xShape.z;
+
+ // Choose behavior based on sign read/write mode.
+ if (signWrite)
+ {
+ // Process value if in p.x.
+ uint32_t s = 0;
+ if (x < p.xShape.x && y < p.xShape.y)
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+
+ // Gain, LReLU, clamp.
+ v *= p.gain;
+ if (v < 0.f)
+ {
+ v *= p.slope;
+ s = 1; // Sign.
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ v = InternalType::clamp(v, p.clamp);
+ s = 2; // Clamp.
+ }
+
+ *pv = (T)v; // Write value.
+ }
+
+ // Coalesce into threads 0 and 16 of warp.
+ uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
+ s <<= ((threadIdx.x & 15) << 1); // Shift into place.
+ s |= __shfl_xor_sync(m, s, 1); // Distribute.
+ s |= __shfl_xor_sync(m, s, 2);
+ s |= __shfl_xor_sync(m, s, 4);
+ s |= __shfl_xor_sync(m, s, 8);
+
+ // Write signs if leader and in p.s.
+ if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
+ {
+ uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
+ ((uint32_t*)p.s)[is >> 4] = s;
+ }
+ }
+ else if (signRead)
+ {
+ // Process value if in p.x.
+ if (x < p.xShape.x) // y is always in.
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+ v *= p.gain;
+
+ // Apply sign buffer offset.
+ uint32_t sx = x + p.sOfs.x;
+ uint32_t sy = y + p.sOfs.y;
+
+ // Read and apply signs if we land inside valid region of sign buffer.
+ if (sx < p.sShape.x && sy < p.sShape.y)
+ {
+ uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
+ unsigned char s = p.s[is];
+ s >>= (sx & 3) << 1; // Shift into place.
+ if (s & 1) // Sign?
+ v *= p.slope;
+ if (s & 2) // Clamp?
+ v = 0.f;
+ }
+
+ *pv = (T)v; // Write value.
+ }
+ }
+ else
+ {
+ // Forward pass with no sign write. Process value if in p.x.
+ if (x < p.xShape.x) // y is always in.
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+ v *= p.gain;
+ if (v < 0.f)
+ v *= p.slope;
+ if (fabsf(v) > p.clamp)
+ v = InternalType::clamp(v, p.clamp);
+ *pv = (T)v; // Write value.
+ }
+ }
+ }
+}
+
+template void* choose_filtered_lrelu_act_kernel(void)
+{
+ return (void*)filtered_lrelu_act_kernel;
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
+{
+ filtered_lrelu_kernel_spec s = { 0 };
+
+ // Return the first matching kernel.
+#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
+ if (sharedKB >= SH) \
+ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
+ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
+ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
+ { \
+ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
+ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
+ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
+ s.setup = (void*)setup_filters_kernel; \
+ s.exec = (void*)filtered_lrelu_kernel; \
+ s.tileOut = make_int2(TW, TH); \
+ s.numWarps = W; \
+ s.xrep = XR; \
+ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
+ return s; \
+ }
+
+ // Launch parameters for various kernel specializations.
+ // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
+ // Kernels that use more shared memory must be listed before those that use less, for the same reason.
+
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
+ CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
+
+ #undef CASE
+ return s; // No kernel found.
+}
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu.h b/modules/eg3ds/torch_utils/ops/filtered_lrelu.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2bfd1dd537909de9cd3b14765a482056391683b
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu.h
@@ -0,0 +1,94 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct filtered_lrelu_kernel_params
+{
+ // These parameters decide which kernel to use.
+ int up; // upsampling ratio (1, 2, 4)
+ int down; // downsampling ratio (1, 2, 4)
+ int2 fuShape; // [size, 1] | [size, size]
+ int2 fdShape; // [size, 1] | [size, size]
+
+ int _dummy; // Alignment.
+
+ // Rest of the parameters.
+ const void* x; // Input tensor.
+ void* y; // Output tensor.
+ const void* b; // Bias tensor.
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
+ const float* fu; // Upsampling filter.
+ const float* fd; // Downsampling filter.
+
+ int2 pad0; // Left/top padding.
+ float gain; // Additional gain factor.
+ float slope; // Leaky ReLU slope on negative side.
+ float clamp; // Clamp after nonlinearity.
+ int flip; // Filter kernel flip for gradient computation.
+
+ int tilesXdim; // Original number of horizontal output tiles.
+ int tilesXrep; // Number of horizontal tiles per CTA.
+ int blockZofs; // Block z offset to support large minibatch, channel dimensions.
+
+ int4 xShape; // [width, height, channel, batch]
+ int4 yShape; // [width, height, channel, batch]
+ int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+ int swLimit; // Active width of sign tensor in bytes.
+
+ longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
+ longlong4 yStride; //
+ int64_t bStride; //
+ longlong3 fuStride; //
+ longlong3 fdStride; //
+};
+
+struct filtered_lrelu_act_kernel_params
+{
+ void* x; // Input/output, modified in-place.
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
+
+ float gain; // Additional gain factor.
+ float slope; // Leaky ReLU slope on negative side.
+ float clamp; // Clamp after nonlinearity.
+
+ int4 xShape; // [width, height, channel, batch]
+ longlong4 xStride; // Input/output tensor strides, same order as in shape.
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
+
+struct filtered_lrelu_kernel_spec
+{
+ void* setup; // Function for filter kernel setup.
+ void* exec; // Function for main operation.
+ int2 tileOut; // Width/height of launch tile.
+ int numWarps; // Number of warps per thread block, determines launch block size.
+ int xrep; // For processing multiple horizontal tiles per thread block.
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template void* choose_filtered_lrelu_act_kernel(void);
+template cudaError_t copy_filters(cudaStream_t stream);
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu.py b/modules/eg3ds/torch_utils/ops/filtered_lrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..2047b7e19320e8d03e444ca1cb03fe00d0c5e96e
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu.py
@@ -0,0 +1,276 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import os
+import numpy as np
+import torch
+import warnings
+
+from .. import custom_ops
+from .. import misc
+from . import upfirdn2d
+from . import bias_act
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+
+def _init():
+ global _plugin
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='filtered_lrelu_plugin',
+ sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
+ headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math'],
+ )
+ return True
+
+def _get_filter_size(f):
+ if f is None:
+ return 1, 1
+ assert isinstance(f, torch.Tensor)
+ assert 1 <= f.ndim <= 2
+ return f.shape[-1], f.shape[0] # width, height
+
+def _parse_padding(padding):
+ if isinstance(padding, int):
+ padding = [padding, padding]
+ assert isinstance(padding, (list, tuple))
+ assert all(isinstance(x, (int, np.integer)) for x in padding)
+ padding = [int(x) for x in padding]
+ if len(padding) == 2:
+ px, py = padding
+ padding = [px, px, py, py]
+ px0, px1, py0, py1 = padding
+ return px0, px1, py0, py1
+
+#----------------------------------------------------------------------------
+
+def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
+ r"""Filtered leaky ReLU for a batch of 2D images.
+
+ Performs the following sequence of operations for each channel:
+
+ 1. Add channel-specific bias if provided (`b`).
+
+ 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
+
+ 3. Pad the image with the specified number of zeros on each side (`padding`).
+ Negative padding corresponds to cropping the image.
+
+ 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
+ so that the footprint of all output pixels lies within the input image.
+
+ 5. Multiply each value by the provided gain factor (`gain`).
+
+ 6. Apply leaky ReLU activation function to each value.
+
+ 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
+
+ 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
+ it so that the footprint of all output pixels lies within the input image.
+
+ 9. Downsample the image by keeping every Nth pixel (`down`).
+
+ The fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Float32/float16/float64 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ fu: Float32 upsampling FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ fd: Float32 downsampling FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The length of vector must must match the channel dimension of `x`.
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor. (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
+ slope: Slope on the negative side of leaky ReLU (default: 0.2).
+ clamp: Maximum magnitude for leaky ReLU output (default: None).
+ flip_filter: False = convolution, True = correlation (default: False).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
+ return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
+ """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
+ existing `upfirdn2n()` and `bias_act()` ops.
+ """
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ fu_w, fu_h = _get_filter_size(fu)
+ fd_w, fd_h = _get_filter_size(fd)
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
+ misc.assert_shape(b, [x.shape[1]])
+ assert isinstance(up, int) and up >= 1
+ assert isinstance(down, int) and down >= 1
+ px0, px1, py0, py1 = _parse_padding(padding)
+ assert gain == float(gain) and gain > 0
+ assert slope == float(slope) and slope >= 0
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+
+ # Calculate output size.
+ batch_size, channels, in_h, in_w = x.shape
+ in_dtype = x.dtype
+ out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
+ out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
+
+ # Compute using existing ops.
+ x = bias_act.bias_act(x=x, b=b) # Apply bias.
+ x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
+ x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
+ x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
+
+ # Check output shape & dtype.
+ misc.assert_shape(x, [batch_size, channels, out_h, out_w])
+ assert x.dtype == in_dtype
+ return x
+
+#----------------------------------------------------------------------------
+
+_filtered_lrelu_cuda_cache = dict()
+
+def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
+ """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
+ """
+ assert isinstance(up, int) and up >= 1
+ assert isinstance(down, int) and down >= 1
+ px0, px1, py0, py1 = _parse_padding(padding)
+ assert gain == float(gain) and gain > 0
+ gain = float(gain)
+ assert slope == float(slope) and slope >= 0
+ slope = float(slope)
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+ clamp = float(clamp if clamp is not None else 'inf')
+
+ # Lookup from cache.
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
+ if key in _filtered_lrelu_cuda_cache:
+ return _filtered_lrelu_cuda_cache[key]
+
+ # Forward op.
+ class FilteredLReluCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+
+ # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
+ if fu is None:
+ fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ if fd is None:
+ fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert 1 <= fu.ndim <= 2
+ assert 1 <= fd.ndim <= 2
+
+ # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
+ if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
+ fu = fu.square()[None]
+ if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
+ fd = fd.square()[None]
+
+ # Missing sign input tensor.
+ if si is None:
+ si = torch.empty([0])
+
+ # Missing bias tensor.
+ if b is None:
+ b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
+
+ # Construct internal sign tensor only if gradients are needed.
+ write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
+
+ # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
+ strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
+ if any(a < b for a, b in zip(strides[:-1], strides[1:])):
+ warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
+
+ # Call C++/Cuda plugin if datatype is supported.
+ if x.dtype in [torch.float16, torch.float32]:
+ if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
+ warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
+ y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
+ else:
+ return_code = -1
+
+ # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
+ # only the bit-packed sign tensor is retained for gradient computation.
+ if return_code < 0:
+ warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
+
+ y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
+ y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
+ so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
+ y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
+
+ # Prepare for gradient computation.
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so))
+ ctx.x_shape = x.shape
+ ctx.y_shape = y.shape
+ ctx.s_ofs = sx, sy
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ fu, fd, si = ctx.saved_tensors
+ _, _, xh, xw = ctx.x_shape
+ _, _, yh, yw = ctx.y_shape
+ sx, sy = ctx.s_ofs
+ dx = None # 0
+ dfu = None; assert not ctx.needs_input_grad[1]
+ dfd = None; assert not ctx.needs_input_grad[2]
+ db = None # 3
+ dsi = None; assert not ctx.needs_input_grad[4]
+ dsx = None; assert not ctx.needs_input_grad[5]
+ dsy = None; assert not ctx.needs_input_grad[6]
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
+ pp = [
+ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
+ xw * up - yw * down + px0 - (up - 1),
+ (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
+ xh * up - yh * down + py0 - (up - 1),
+ ]
+ gg = gain * (up ** 2) / (down ** 2)
+ ff = (not flip_filter)
+ sx = sx - (fu.shape[-1] - 1) + px0
+ sy = sy - (fu.shape[0] - 1) + py0
+ dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
+
+ if ctx.needs_input_grad[3]:
+ db = dx.sum([0, 2, 3])
+
+ return dx, dfu, dfd, db, dsi, dsx, dsy
+
+ # Add to cache.
+ _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
+ return FilteredLReluCuda
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu_ns.cu b/modules/eg3ds/torch_utils/ops/filtered_lrelu_ns.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8a3eae46215c3babea2c54e3ae255b05f4d777af
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu_ns.cu
@@ -0,0 +1,31 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for no signs mode (no gradients required).
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters(cudaStream_t stream);
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu_rd.cu b/modules/eg3ds/torch_utils/ops/filtered_lrelu_rd.cu
new file mode 100644
index 0000000000000000000000000000000000000000..3cd43ec0648d3db05e5808299fc0ee318e5ceaa6
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu_rd.cu
@@ -0,0 +1,31 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for sign read mode.
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters(cudaStream_t stream);
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu_wr.cu b/modules/eg3ds/torch_utils/ops/filtered_lrelu_wr.cu
new file mode 100644
index 0000000000000000000000000000000000000000..bc2fa06912eb703dd77ca64533208428bdf373ac
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu_wr.cu
@@ -0,0 +1,31 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for sign write mode.
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters(cudaStream_t stream);
diff --git a/modules/eg3ds/torch_utils/ops/fma.py b/modules/eg3ds/torch_utils/ops/fma.py
new file mode 100644
index 0000000000000000000000000000000000000000..5458116d0b6f8b133608456bbe9003aa0283ac85
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/fma.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
+
+import torch
+
+#----------------------------------------------------------------------------
+
+def fma(a, b, c): # => a * b + c
+ return _FusedMultiplyAdd.apply(a, b, c)
+
+#----------------------------------------------------------------------------
+
+class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
+ @staticmethod
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
+ out = torch.addcmul(c, a, b)
+ ctx.save_for_backward(a, b)
+ ctx.c_shape = c.shape
+ return out
+
+ @staticmethod
+ def backward(ctx, dout): # pylint: disable=arguments-differ
+ a, b = ctx.saved_tensors
+ c_shape = ctx.c_shape
+ da = None
+ db = None
+ dc = None
+
+ if ctx.needs_input_grad[0]:
+ da = _unbroadcast(dout * b, a.shape)
+
+ if ctx.needs_input_grad[1]:
+ db = _unbroadcast(dout * a, b.shape)
+
+ if ctx.needs_input_grad[2]:
+ dc = _unbroadcast(dout, c_shape)
+
+ return da, db, dc
+
+#----------------------------------------------------------------------------
+
+def _unbroadcast(x, shape):
+ extra_dims = x.ndim - len(shape)
+ assert extra_dims >= 0
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
+ if len(dim):
+ x = x.sum(dim=dim, keepdim=True)
+ if extra_dims:
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
+ assert x.shape == shape
+ return x
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/grid_sample_gradfix.py b/modules/eg3ds/torch_utils/ops/grid_sample_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..35d94724136ba162d8416803b1ad00d6da0db99f
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/grid_sample_gradfix.py
@@ -0,0 +1,79 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.grid_sample` that
+supports arbitrarily high order gradients between the input and output.
+Only works on 2D images and assumes
+`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
+
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+
+#----------------------------------------------------------------------------
+
+def grid_sample(input, grid):
+ if _should_use_custom_op():
+ return _GridSample2dForward.apply(input, grid)
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op():
+ return enabled
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dForward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, grid):
+ assert input.ndim == 4
+ assert grid.ndim == 4
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+ ctx.save_for_backward(input, grid)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, grid = ctx.saved_tensors
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
+ return grad_input, grad_grid
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, grid):
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+ ctx.save_for_backward(grid)
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ _ = grad2_grad_grid # unused
+ grid, = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+ grad2_grid = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
+
+ assert not ctx.needs_input_grad[2]
+ return grad2_grad_output, grad2_input, grad2_grid
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/upfirdn2d.cpp b/modules/eg3ds/torch_utils/ops/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c1769c3cbe4dd04f76f9ccef726680720e6f39c8
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/upfirdn2d.cpp
@@ -0,0 +1,111 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include
+#include
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+ TORCH_CHECK(x.numel() > 0, "x has zero size");
+ TORCH_CHECK(f.numel() > 0, "f has zero size");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+ TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+ TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
+
+ // Initialize CUDA kernel parameters.
+ upfirdn2d_kernel_params p;
+ p.x = x.data_ptr();
+ p.f = f.data_ptr();
+ p.y = y.data_ptr();
+ p.up = make_int2(upx, upy);
+ p.down = make_int2(downx, downy);
+ p.pad0 = make_int2(padx0, pady0);
+ p.flip = (flip) ? 1 : 0;
+ p.gain = gain;
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
+
+ // Choose CUDA kernel.
+ upfirdn2d_kernel_spec spec;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ spec = choose_upfirdn2d_kernel(p);
+ });
+
+ // Set looping options.
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
+ p.loopMinor = spec.loopMinor;
+ p.loopX = spec.loopX;
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+ // Compute grid size.
+ dim3 blockSize, gridSize;
+ if (spec.tileOutW < 0) // large
+ {
+ blockSize = dim3(4, 32, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
+ p.launchMajor);
+ }
+ else // small
+ {
+ blockSize = dim3(256, 1, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
+ p.launchMajor);
+ }
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("upfirdn2d", &upfirdn2d);
+}
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/upfirdn2d.cu b/modules/eg3ds/torch_utils/ops/upfirdn2d.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7d182d7b86a9058d0c007b13716d6e7f08207f42
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/upfirdn2d.cu
@@ -0,0 +1,388 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template struct InternalType;
+template <> struct InternalType { typedef double scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+
+static __device__ __forceinline__ int floor_div(int a, int b)
+{
+ int t = 1 - a / b;
+ return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// Generic CUDA implementation for large filters.
+
+template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+
+ // Calculate thread index.
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
+ int outY = minorBase / p.launchMinor;
+ minorBase -= outY * p.launchMinor;
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Setup Y receptive field.
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
+ if (p.flip)
+ filterY = p.filterSize.y - 1 - filterY;
+
+ // Loop over major, minor, and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
+ {
+ int nc = major * p.sizeMinor + minor;
+ int n = nc / p.inSize.z;
+ int c = nc - n * p.inSize.z;
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
+ {
+ // Setup X receptive field.
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
+ if (p.flip)
+ filterX = p.filterSize.x - 1 - filterX;
+
+ // Initialize pointers.
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
+
+ // Inner loop.
+ scalar_t v = 0;
+ for (int y = 0; y < h; y++)
+ {
+ for (int x = 0; x < w; x++)
+ {
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
+ xp += p.inStride.x;
+ fp += filterStepX;
+ }
+ xp += p.inStride.y - w * p.inStride.x;
+ fp += filterStepY - w * filterStepX;
+ }
+
+ // Store result.
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filters.
+
+template
+static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
+ __shared__ volatile scalar_t sf[filterH][filterW];
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
+
+ // Calculate tile index.
+ int minorBase = blockIdx.x;
+ int tileOutY = minorBase / p.launchMinor;
+ minorBase -= tileOutY * p.launchMinor;
+ minorBase *= loopMinor;
+ tileOutY *= tileOutH;
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Load filter (flipped).
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
+ {
+ int fy = tapIdx / filterW;
+ int fx = tapIdx - fy * filterW;
+ scalar_t v = 0;
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
+ {
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
+ }
+ sf[fy][fx] = v;
+ }
+
+ // Loop over major and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ {
+ int baseNC = major * p.sizeMinor + minorBase;
+ int n = baseNC / p.inSize.z;
+ int baseC = baseNC - n * p.inSize.z;
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
+ {
+ // Load input pixels.
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
+ int tileInX = floor_div(tileMidX, upx);
+ int tileInY = floor_div(tileMidY, upy);
+ __syncthreads();
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
+ {
+ int relC = inIdx;
+ int relInX = relC / loopMinor;
+ int relInY = relInX / tileInW;
+ relC -= relInX * loopMinor;
+ relInX -= relInY * tileInW;
+ int c = baseC + relC;
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ sx[relInY][relInX][relC] = v;
+ }
+
+ // Loop over output pixels.
+ __syncthreads();
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
+ {
+ int relC = outIdx;
+ int relOutX = relC / loopMinor;
+ int relOutY = relOutX / tileOutW;
+ relC -= relOutX * loopMinor;
+ relOutX -= relOutY * tileOutW;
+ int c = baseC + relC;
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY;
+
+ // Setup receptive field.
+ int midX = tileMidX + relOutX * downx;
+ int midY = tileMidY + relOutY * downy;
+ int inX = floor_div(midX, upx);
+ int inY = floor_div(midY, upy);
+ int relInX = inX - tileInX;
+ int relInY = inY - tileInY;
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
+
+ // Inner loop.
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
+ {
+ scalar_t v = 0;
+ #pragma unroll
+ for (int y = 0; y < filterH / upy; y++)
+ #pragma unroll
+ for (int x = 0; x < filterW / upx; x++)
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last
+
+ // No up/downsampling.
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1};
+ if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ }
+
+ // 2x upsampling.
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small