Spaces:
Running on Zero
Running on Zero
Jack Wu commited on
Commit ·
c60109f
1
Parent(s): 6cf4573
This view is limited to 50 files because it contains too many changes. See raw diff
- .idea/.gitignore +10 -0
- .idea/Generate_Audio_for_Video.iml +14 -0
- .idea/inspectionProfiles/Project_Default.xml +7 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- HunyuanVideo-Foley/.gitattributes +3 -0
- HunyuanVideo-Foley/.gitignore +159 -0
- HunyuanVideo-Foley/.pre-commit-config.yaml +38 -0
- HunyuanVideo-Foley/DEVELOPMENT.md +187 -0
- HunyuanVideo-Foley/INSTALL.md +203 -0
- HunyuanVideo-Foley/LICENSE +77 -0
- HunyuanVideo-Foley/MANIFEST.in +38 -0
- HunyuanVideo-Foley/NOTICE +27 -0
- HunyuanVideo-Foley/README.md +519 -0
- HunyuanVideo-Foley/build_package.sh +58 -0
- HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml +48 -0
- HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml +48 -0
- HunyuanVideo-Foley/download_test_videos.sh +11 -0
- HunyuanVideo-Foley/gradio_app.py +834 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/__init__.py +30 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/cli.py +141 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/constants.py +57 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/__init__.py +0 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/__init__.py +16 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/__main__.py +36 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/__init__.py +4 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/base.py +301 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/dac.py +410 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/discriminator.py +228 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/__init__.py +3 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/layers.py +33 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/loss.py +368 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/quantize.py +262 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py +91 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/__init__.py +121 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/decode.py +95 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/encode.py +94 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/hifi_foley.py +794 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/__init__.py +0 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/activation_layers.py +44 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/attn_layers.py +546 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/embed_layers.py +136 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/mlp_layers.py +149 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/modulate_layers.py +49 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/norm_layers.py +70 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/posemb_layers.py +159 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/__init__.py +1 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/ast_model.py +289 -0
- HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/compute_desync_score.py +214 -0
.idea/.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Ignored default folder with query files
|
| 5 |
+
/queries/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
| 9 |
+
# Editor-based HTTP Client requests
|
| 10 |
+
/httpRequests/
|
.idea/Generate_Audio_for_Video.iml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$">
|
| 5 |
+
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
| 6 |
+
</content>
|
| 7 |
+
<orderEntry type="inheritedJdk" />
|
| 8 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 9 |
+
</component>
|
| 10 |
+
<component name="PyDocumentationSettings">
|
| 11 |
+
<option name="format" value="PLAIN" />
|
| 12 |
+
<option name="myDocStringFormat" value="Plain" />
|
| 13 |
+
</component>
|
| 14 |
+
</module>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<profile version="1.0">
|
| 3 |
+
<option name="myName" value="Project Default" />
|
| 4 |
+
<inspection_tool class="SqlNoDataSourceInspection" enabled="false" level="WARNING" enabled_by_default="false" />
|
| 5 |
+
<inspection_tool class="TodoComment" enabled="false" level="INFORMATION" enabled_by_default="false" />
|
| 6 |
+
</profile>
|
| 7 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/Generate_Audio_for_Video.iml" filepath="$PROJECT_DIR$/.idea/Generate_Audio_for_Video.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="" vcs="Git" />
|
| 5 |
+
</component>
|
| 6 |
+
</project>
|
HunyuanVideo-Foley/.gitattributes
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
assets/data_pipeline.png filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
assets/model_arch.png filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
HunyuanVideo-Foley/.gitignore
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
# Usually these files are written by a python script from a template
|
| 32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
.coverage
|
| 45 |
+
.coverage.*
|
| 46 |
+
.cache
|
| 47 |
+
nosetests.xml
|
| 48 |
+
coverage.xml
|
| 49 |
+
*.cover
|
| 50 |
+
*.py,cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
target/
|
| 76 |
+
|
| 77 |
+
# Jupyter Notebook
|
| 78 |
+
.ipynb_checkpoints
|
| 79 |
+
|
| 80 |
+
# IPython
|
| 81 |
+
profile_default/
|
| 82 |
+
ipython_config.py
|
| 83 |
+
|
| 84 |
+
# pyenv
|
| 85 |
+
.python-version
|
| 86 |
+
|
| 87 |
+
# pipenv
|
| 88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 91 |
+
# install all needed dependencies.
|
| 92 |
+
#Pipfile.lock
|
| 93 |
+
|
| 94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 95 |
+
__pypackages__/
|
| 96 |
+
|
| 97 |
+
# Celery stuff
|
| 98 |
+
celerybeat-schedule
|
| 99 |
+
celerybeat.pid
|
| 100 |
+
|
| 101 |
+
# SageMath parsed files
|
| 102 |
+
*.sage.py
|
| 103 |
+
|
| 104 |
+
# Environments
|
| 105 |
+
.env
|
| 106 |
+
.venv
|
| 107 |
+
env/
|
| 108 |
+
venv/
|
| 109 |
+
ENV/
|
| 110 |
+
env.bak/
|
| 111 |
+
venv.bak/
|
| 112 |
+
|
| 113 |
+
# Spyder project settings
|
| 114 |
+
.spyderproject
|
| 115 |
+
.spyproject
|
| 116 |
+
|
| 117 |
+
# Rope project settings
|
| 118 |
+
.ropeproject
|
| 119 |
+
|
| 120 |
+
# mkdocs documentation
|
| 121 |
+
/site
|
| 122 |
+
|
| 123 |
+
# mypy
|
| 124 |
+
.mypy_cache/
|
| 125 |
+
.dmypy.json
|
| 126 |
+
dmypy.json
|
| 127 |
+
|
| 128 |
+
# Pyre type checker
|
| 129 |
+
.pyre/
|
| 130 |
+
|
| 131 |
+
# ==========================================
|
| 132 |
+
# Custom settings
|
| 133 |
+
# ==========================================
|
| 134 |
+
|
| 135 |
+
# For MacOS
|
| 136 |
+
.DS_Store
|
| 137 |
+
|
| 138 |
+
# For IDEs
|
| 139 |
+
.idea/
|
| 140 |
+
.vscode/
|
| 141 |
+
pyrightconfig.json
|
| 142 |
+
.cursorignore
|
| 143 |
+
|
| 144 |
+
assets/
|
| 145 |
+
examples/
|
| 146 |
+
|
| 147 |
+
# For global settings
|
| 148 |
+
__*/
|
| 149 |
+
**/my_*
|
| 150 |
+
tmp*.*
|
| 151 |
+
.my*
|
| 152 |
+
# Model checkpoints
|
| 153 |
+
*.pt
|
| 154 |
+
*.ckpt
|
| 155 |
+
*.pth
|
| 156 |
+
*.safetensors
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
CLAUDE.md
|
HunyuanVideo-Foley/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
+
rev: v4.4.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: trailing-whitespace
|
| 6 |
+
- id: end-of-file-fixer
|
| 7 |
+
- id: check-yaml
|
| 8 |
+
- id: check-added-large-files
|
| 9 |
+
- id: check-merge-conflict
|
| 10 |
+
- id: debug-statements
|
| 11 |
+
- id: check-docstring-first
|
| 12 |
+
|
| 13 |
+
- repo: https://github.com/psf/black
|
| 14 |
+
rev: 23.3.0
|
| 15 |
+
hooks:
|
| 16 |
+
- id: black
|
| 17 |
+
language_version: python3
|
| 18 |
+
args: [--line-length=120]
|
| 19 |
+
|
| 20 |
+
- repo: https://github.com/pycqa/isort
|
| 21 |
+
rev: 5.12.0
|
| 22 |
+
hooks:
|
| 23 |
+
- id: isort
|
| 24 |
+
args: [--profile, black, --line-length=120]
|
| 25 |
+
|
| 26 |
+
- repo: https://github.com/pycqa/flake8
|
| 27 |
+
rev: 6.0.0
|
| 28 |
+
hooks:
|
| 29 |
+
- id: flake8
|
| 30 |
+
args: [--max-line-length=120]
|
| 31 |
+
additional_dependencies: [flake8-docstrings]
|
| 32 |
+
|
| 33 |
+
- repo: https://github.com/pre-commit/mirrors-mypy
|
| 34 |
+
rev: v1.3.0
|
| 35 |
+
hooks:
|
| 36 |
+
- id: mypy
|
| 37 |
+
additional_dependencies: [types-all]
|
| 38 |
+
args: [--ignore-missing-imports]
|
HunyuanVideo-Foley/DEVELOPMENT.md
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Development Guide
|
| 2 |
+
|
| 3 |
+
This document provides guidelines for developing and contributing to the HunyuanVideo-Foley project.
|
| 4 |
+
|
| 5 |
+
## Code Style and Quality
|
| 6 |
+
|
| 7 |
+
### Code Formatting
|
| 8 |
+
|
| 9 |
+
We use the following tools to maintain consistent code style:
|
| 10 |
+
|
| 11 |
+
- **Black**: Code formatter with 120 character line length
|
| 12 |
+
- **isort**: Import sorter compatible with Black
|
| 13 |
+
- **flake8**: Linting and style checking
|
| 14 |
+
- **mypy**: Static type checking
|
| 15 |
+
|
| 16 |
+
### Pre-commit Hooks
|
| 17 |
+
|
| 18 |
+
Install pre-commit hooks to automatically format code before commits:
|
| 19 |
+
|
| 20 |
+
```bash
|
| 21 |
+
pip install pre-commit
|
| 22 |
+
pre-commit install
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### Manual Code Formatting
|
| 26 |
+
|
| 27 |
+
Format code manually:
|
| 28 |
+
|
| 29 |
+
```bash
|
| 30 |
+
# Format all Python files
|
| 31 |
+
black --line-length 120 .
|
| 32 |
+
|
| 33 |
+
# Sort imports
|
| 34 |
+
isort --profile black --line-length 120 .
|
| 35 |
+
|
| 36 |
+
# Check code style
|
| 37 |
+
flake8 --max-line-length 120
|
| 38 |
+
|
| 39 |
+
# Type checking
|
| 40 |
+
mypy --ignore-missing-imports .
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
## Project Structure
|
| 44 |
+
|
| 45 |
+
```
|
| 46 |
+
hunyuanvideo_foley/
|
| 47 |
+
├── models/ # Model implementations
|
| 48 |
+
│ ├── hifi_foley.py # Main model
|
| 49 |
+
│ ├── nn/ # Neural network layers
|
| 50 |
+
│ ├── dac_vae/ # Audio VAE
|
| 51 |
+
│ └── synchformer/ # Synchronization model
|
| 52 |
+
├── utils/ # Utilities
|
| 53 |
+
│ ├── config_utils.py # Configuration handling
|
| 54 |
+
│ ├── feature_utils.py # Feature extraction
|
| 55 |
+
│ ├── model_utils.py # Model loading/saving
|
| 56 |
+
│ └── media_utils.py # Audio/video processing
|
| 57 |
+
└── constants.py # Project constants
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
## Coding Standards
|
| 61 |
+
|
| 62 |
+
### Error Handling
|
| 63 |
+
|
| 64 |
+
- Use custom exceptions for domain-specific errors
|
| 65 |
+
- Always validate inputs at function boundaries
|
| 66 |
+
- Log errors with appropriate levels (ERROR, WARNING, INFO)
|
| 67 |
+
- Provide helpful error messages to users
|
| 68 |
+
|
| 69 |
+
### Type Hints
|
| 70 |
+
|
| 71 |
+
- Add type hints to all function parameters and return values
|
| 72 |
+
- Use `Optional[Type]` for nullable parameters
|
| 73 |
+
- Import types from `typing` module
|
| 74 |
+
|
| 75 |
+
### Documentation
|
| 76 |
+
|
| 77 |
+
- Add docstrings to all public functions and classes
|
| 78 |
+
- Use Google-style docstrings
|
| 79 |
+
- Document parameters, return values, and exceptions
|
| 80 |
+
|
| 81 |
+
### Example Function
|
| 82 |
+
|
| 83 |
+
```python
|
| 84 |
+
def process_video(
|
| 85 |
+
video_path: str,
|
| 86 |
+
max_duration: Optional[float] = None
|
| 87 |
+
) -> Tuple[np.ndarray, float]:
|
| 88 |
+
"""
|
| 89 |
+
Process video file and extract frames.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
video_path: Path to input video file
|
| 93 |
+
max_duration: Maximum duration in seconds (optional)
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
Tuple of (frames array, duration in seconds)
|
| 97 |
+
|
| 98 |
+
Raises:
|
| 99 |
+
FileNotFoundError: If video file doesn't exist
|
| 100 |
+
VideoProcessingError: If video processing fails
|
| 101 |
+
"""
|
| 102 |
+
if not os.path.exists(video_path):
|
| 103 |
+
raise FileNotFoundError(f"Video file not found: {video_path}")
|
| 104 |
+
|
| 105 |
+
# Implementation here...
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## Testing
|
| 109 |
+
|
| 110 |
+
### Running Tests
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
# Run all tests
|
| 114 |
+
python -m pytest
|
| 115 |
+
|
| 116 |
+
# Run specific test file
|
| 117 |
+
python -m pytest tests/test_feature_utils.py
|
| 118 |
+
|
| 119 |
+
# Run with coverage
|
| 120 |
+
python -m pytest --cov=hunyuanvideo_foley
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
### Writing Tests
|
| 124 |
+
|
| 125 |
+
- Place tests in `tests/` directory
|
| 126 |
+
- Name test files as `test_*.py`
|
| 127 |
+
- Use descriptive test function names
|
| 128 |
+
- Test edge cases and error conditions
|
| 129 |
+
|
| 130 |
+
## Development Workflow
|
| 131 |
+
|
| 132 |
+
1. **Setup Environment**
|
| 133 |
+
```bash
|
| 134 |
+
python -m venv venv
|
| 135 |
+
source venv/bin/activate # Linux/Mac
|
| 136 |
+
# or
|
| 137 |
+
venv\Scripts\activate # Windows
|
| 138 |
+
|
| 139 |
+
pip install -r requirements.txt
|
| 140 |
+
pip install -e .
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
2. **Install Development Tools**
|
| 144 |
+
```bash
|
| 145 |
+
pre-commit install
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
3. **Make Changes**
|
| 149 |
+
- Follow the coding standards above
|
| 150 |
+
- Add tests for new functionality
|
| 151 |
+
- Update documentation as needed
|
| 152 |
+
|
| 153 |
+
4. **Run Quality Checks**
|
| 154 |
+
```bash
|
| 155 |
+
black --check --line-length 120 .
|
| 156 |
+
isort --check-only --profile black .
|
| 157 |
+
flake8 --max-line-length 120
|
| 158 |
+
mypy --ignore-missing-imports .
|
| 159 |
+
pytest
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
5. **Commit Changes**
|
| 163 |
+
```bash
|
| 164 |
+
git add .
|
| 165 |
+
git commit -m "feat: add new feature"
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
## Performance Considerations
|
| 169 |
+
|
| 170 |
+
- Use `torch.no_grad()` for inference-only code
|
| 171 |
+
- Leverage GPU when available
|
| 172 |
+
- Implement batch processing where possible
|
| 173 |
+
- Profile code to identify bottlenecks
|
| 174 |
+
|
| 175 |
+
## Dependencies
|
| 176 |
+
|
| 177 |
+
- Keep dependencies minimal and well-maintained
|
| 178 |
+
- Pin versions for reproducibility
|
| 179 |
+
- Separate development dependencies from runtime dependencies
|
| 180 |
+
- Document any special installation requirements
|
| 181 |
+
|
| 182 |
+
## Configuration
|
| 183 |
+
|
| 184 |
+
- Use centralized configuration in `constants.py`
|
| 185 |
+
- Support environment variable overrides
|
| 186 |
+
- Provide sensible defaults for all parameters
|
| 187 |
+
- Validate configuration at startup
|
HunyuanVideo-Foley/INSTALL.md
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 安装指南 - HunyuanVideo-Foley
|
| 2 |
+
|
| 3 |
+
本文档提供了将 HunyuanVideo-Foley 作为 Python 包安装和使用的详细指南。
|
| 4 |
+
|
| 5 |
+
## 安装方式
|
| 6 |
+
|
| 7 |
+
### 方式1:从源码安装(推荐)
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
# 克隆仓库
|
| 11 |
+
git clone https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley
|
| 12 |
+
cd HunyuanVideo-Foley
|
| 13 |
+
|
| 14 |
+
# 安装包(开发模式)
|
| 15 |
+
pip install -e .
|
| 16 |
+
|
| 17 |
+
# 或安装包含所有可选依赖
|
| 18 |
+
pip install -e .[all]
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
### 方式2:直接从GitHub安装
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
pip install git+https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley.git
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
### 方式3:构建wheel包安装
|
| 28 |
+
|
| 29 |
+
```bash
|
| 30 |
+
# 在项目根目录下
|
| 31 |
+
python setup.py bdist_wheel
|
| 32 |
+
pip install dist/hunyuanvideo_foley-1.0.0-py3-none-any.whl
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## 特殊依赖安装
|
| 36 |
+
|
| 37 |
+
由于某些依赖不在PyPI上,需要单独安装:
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
# 安装audiotools(必需)
|
| 41 |
+
pip install git+https://github.com/descriptinc/audiotools
|
| 42 |
+
|
| 43 |
+
# 安装特定版本的transformers(支持SigLIP2)
|
| 44 |
+
pip install git+https://github.com/huggingface/transformers@v4.49.0-SigLIP-2
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## 可选依赖安装
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
# 安装开发依赖
|
| 51 |
+
pip install hunyuanvideo-foley[dev]
|
| 52 |
+
|
| 53 |
+
# 安装测试依赖
|
| 54 |
+
pip install hunyuanvideo-foley[test]
|
| 55 |
+
|
| 56 |
+
# 安装Gradio界面依赖
|
| 57 |
+
pip install hunyuanvideo-foley[gradio]
|
| 58 |
+
|
| 59 |
+
# 安装所有可选依赖
|
| 60 |
+
pip install hunyuanvideo-foley[all]
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## 验证安装
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
# 检查包是否正确安装
|
| 67 |
+
python -c "import hunyuanvideo_foley; print(hunyuanvideo_foley.__version__)"
|
| 68 |
+
|
| 69 |
+
# 检查命令行工具
|
| 70 |
+
hunyuanvideo-foley --help
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
## 使用方法
|
| 74 |
+
|
| 75 |
+
### 1. 作为Python包使用
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
import hunyuanvideo_foley as hvf
|
| 79 |
+
|
| 80 |
+
# 加载模型
|
| 81 |
+
model_dict, cfg = hvf.load_model(
|
| 82 |
+
model_path="path/to/model",
|
| 83 |
+
config_path="configs/hunyuanvideo-foley-xxl.yaml"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# 处理特征
|
| 87 |
+
visual_feats, text_feats, audio_len = hvf.feature_process(
|
| 88 |
+
video_path="video.mp4",
|
| 89 |
+
prompt="footsteps on gravel",
|
| 90 |
+
model_dict=model_dict,
|
| 91 |
+
cfg=cfg
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# 生成音频
|
| 95 |
+
audio, sample_rate = hvf.denoise_process(
|
| 96 |
+
visual_feats, text_feats, audio_len,
|
| 97 |
+
model_dict, cfg
|
| 98 |
+
)
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
### 2. 使用命令行工具
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
# 单个视频处理
|
| 105 |
+
hunyuanvideo-foley \
|
| 106 |
+
--model_path ./pretrained_models \
|
| 107 |
+
--single_video video.mp4 \
|
| 108 |
+
--single_prompt "footsteps on gravel" \
|
| 109 |
+
--output_dir ./outputs
|
| 110 |
+
|
| 111 |
+
# 批量处理
|
| 112 |
+
hunyuanvideo-foley \
|
| 113 |
+
--model_path ./pretrained_models \
|
| 114 |
+
--csv_path batch_videos.csv \
|
| 115 |
+
--output_dir ./outputs
|
| 116 |
+
|
| 117 |
+
# 启动Gradio界面
|
| 118 |
+
hunyuanvideo-foley --gradio --model_path ./pretrained_models
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
### 3. 使用原始脚本(向后兼容)
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
# 使用原始infer.py脚本
|
| 125 |
+
python infer.py --model_path ./pretrained_models --single_video video.mp4 --single_prompt "audio description"
|
| 126 |
+
|
| 127 |
+
# 启动Gradio应用
|
| 128 |
+
export HIFI_FOLEY_MODEL_PATH=./pretrained_models
|
| 129 |
+
python gradio_app.py
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
## 开发环境设置
|
| 133 |
+
|
| 134 |
+
如果你想参与开发:
|
| 135 |
+
|
| 136 |
+
```bash
|
| 137 |
+
# 克隆项目
|
| 138 |
+
git clone https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley
|
| 139 |
+
cd HunyuanVideo-Foley
|
| 140 |
+
|
| 141 |
+
# 安装开发版本
|
| 142 |
+
pip install -e .[dev]
|
| 143 |
+
|
| 144 |
+
# 安装pre-commit钩子
|
| 145 |
+
pre-commit install
|
| 146 |
+
|
| 147 |
+
# 运行测试
|
| 148 |
+
python -m pytest
|
| 149 |
+
|
| 150 |
+
# 代码格式化
|
| 151 |
+
black --line-length 120 .
|
| 152 |
+
isort --profile black .
|
| 153 |
+
|
| 154 |
+
# 类型检查
|
| 155 |
+
mypy --ignore-missing-imports .
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
## 系统要求
|
| 159 |
+
|
| 160 |
+
- **Python**: 3.8+
|
| 161 |
+
- **操作系统**: Linux(主要支持),macOS,Windows
|
| 162 |
+
- **GPU内存**: 推荐 ≥24GB VRAM(如RTX 3090/4090)
|
| 163 |
+
- **CUDA版本**: 12.4 或 11.8(推荐)
|
| 164 |
+
|
| 165 |
+
## 故障排除
|
| 166 |
+
|
| 167 |
+
### 常见问题
|
| 168 |
+
|
| 169 |
+
1. **ImportError: No module named 'audiotools'**
|
| 170 |
+
```bash
|
| 171 |
+
pip install git+https://github.com/descriptinc/audiotools
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
2. **CUDA内存不足**
|
| 175 |
+
- 使用较小的批次大小
|
| 176 |
+
- 确保GPU有足够的VRAM(推荐24GB+)
|
| 177 |
+
|
| 178 |
+
3. **transformers版本问题**
|
| 179 |
+
```bash
|
| 180 |
+
pip install git+https://github.com/huggingface/transformers@v4.49.0-SigLIP-2
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
### 获取帮助
|
| 184 |
+
|
| 185 |
+
- 查看项目README: [GitHub](https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley)
|
| 186 |
+
- 报告问题: [GitHub Issues](https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley/issues)
|
| 187 |
+
- 论文: [arXiv:2508.16930](https://arxiv.org/abs/2508.16930)
|
| 188 |
+
|
| 189 |
+
## 模型下载
|
| 190 |
+
|
| 191 |
+
```bash
|
| 192 |
+
# 使用HuggingFace Hub
|
| 193 |
+
git clone https://huggingface.co/tencent/HunyuanVideo-Foley
|
| 194 |
+
|
| 195 |
+
# 或使用huggingface-cli
|
| 196 |
+
huggingface-cli download tencent/HunyuanVideo-Foley
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
## 配置文件
|
| 200 |
+
|
| 201 |
+
包安装后,配置文件位于:
|
| 202 |
+
- `hunyuanvideo_foley/configs/` 目录
|
| 203 |
+
- 默认配置:`configs/hunyuanvideo-foley-xxl.yaml`
|
HunyuanVideo-Foley/LICENSE
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
|
| 2 |
+
Tencent HunyuanVideo-Foley Release Date: August 28, 2025
|
| 3 |
+
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
| 4 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
| 5 |
+
1. DEFINITIONS.
|
| 6 |
+
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
| 7 |
+
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
|
| 8 |
+
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
|
| 9 |
+
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
| 10 |
+
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
|
| 11 |
+
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
| 12 |
+
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
| 13 |
+
h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
|
| 14 |
+
i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
|
| 15 |
+
j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo-Foley released at [https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley].
|
| 16 |
+
k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
| 17 |
+
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
|
| 18 |
+
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
| 19 |
+
n. “including” shall mean including but not limited to.
|
| 20 |
+
2. GRANT OF RIGHTS.
|
| 21 |
+
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
| 22 |
+
3. DISTRIBUTION.
|
| 23 |
+
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
| 24 |
+
a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
|
| 25 |
+
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
| 26 |
+
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
| 27 |
+
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
| 28 |
+
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
| 29 |
+
4. ADDITIONAL COMMERCIAL TERMS.
|
| 30 |
+
If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
| 31 |
+
5. RULES OF USE.
|
| 32 |
+
a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
| 33 |
+
b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
|
| 34 |
+
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
| 35 |
+
6. INTELLECTUAL PROPERTY.
|
| 36 |
+
a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
| 37 |
+
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
| 38 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
|
| 39 |
+
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
| 40 |
+
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
| 41 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
|
| 42 |
+
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
| 43 |
+
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 44 |
+
8. SURVIVAL AND TERMINATION.
|
| 45 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
| 46 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
| 47 |
+
9. GOVERNING LAW AND JURISDICTION.
|
| 48 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
| 49 |
+
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
| 50 |
+
|
| 51 |
+
EXHIBIT A
|
| 52 |
+
ACCEPTABLE USE POLICY
|
| 53 |
+
|
| 54 |
+
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
| 55 |
+
Last modified: November 5, 2024
|
| 56 |
+
|
| 57 |
+
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
|
| 58 |
+
1. Outside the Territory;
|
| 59 |
+
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
| 60 |
+
3. To harm Yourself or others;
|
| 61 |
+
4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
|
| 62 |
+
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
| 63 |
+
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
| 64 |
+
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
| 65 |
+
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
| 66 |
+
9. To intentionally defame, disparage or otherwise harass others;
|
| 67 |
+
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
| 68 |
+
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
| 69 |
+
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including – through the use of bot-generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
| 70 |
+
13. To impersonate another individual without consent, authorization, or legal right;
|
| 71 |
+
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
| 72 |
+
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
| 73 |
+
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
| 74 |
+
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
| 75 |
+
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
| 76 |
+
19. For military purposes;
|
| 77 |
+
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
HunyuanVideo-Foley/MANIFEST.in
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Include package metadata and documentation
|
| 2 |
+
include README.md
|
| 3 |
+
include LICENSE
|
| 4 |
+
include NOTICE
|
| 5 |
+
include DEVELOPMENT.md
|
| 6 |
+
include CLAUDE.md
|
| 7 |
+
include requirements.txt
|
| 8 |
+
include pyproject.toml
|
| 9 |
+
include pytest.ini
|
| 10 |
+
|
| 11 |
+
# Include configuration files
|
| 12 |
+
include configs/*.yaml
|
| 13 |
+
include configs/*.yml
|
| 14 |
+
recursive-include hunyuanvideo_foley/configs *.yaml *.yml
|
| 15 |
+
|
| 16 |
+
# Include test assets if any
|
| 17 |
+
include assets/*.csv
|
| 18 |
+
include assets/*.txt
|
| 19 |
+
recursive-include assets/test_videos *
|
| 20 |
+
|
| 21 |
+
# Include example scripts
|
| 22 |
+
include *.py
|
| 23 |
+
include *.sh
|
| 24 |
+
|
| 25 |
+
# Include test files
|
| 26 |
+
recursive-include tests *.py
|
| 27 |
+
|
| 28 |
+
# Exclude unnecessary files
|
| 29 |
+
global-exclude *.pyc
|
| 30 |
+
global-exclude *.pyo
|
| 31 |
+
global-exclude *~
|
| 32 |
+
global-exclude .DS_Store
|
| 33 |
+
global-exclude __pycache__
|
| 34 |
+
prune .git
|
| 35 |
+
prune .github
|
| 36 |
+
prune examples/*/outputs
|
| 37 |
+
prune **/__pycache__
|
| 38 |
+
prune **/*.pyc
|
HunyuanVideo-Foley/NOTICE
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Usage and Legal Notices:
|
| 2 |
+
|
| 3 |
+
Tencent is pleased to support the open source community by making Tencent HunyuanVideo-Foley available.
|
| 4 |
+
|
| 5 |
+
Copyright (C) 2025 Tencent. All rights reserved.
|
| 6 |
+
|
| 7 |
+
Tencent HunyuanVideo-Foley is licensed under TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT, which can be found in this repository called "LICENSE", except for the third-party components listed below. Tencent HunyuanVideo-Foley does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
| 8 |
+
|
| 9 |
+
For avoidance of doubts, Tencent HunyuanVideo-Foley means the large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
Other dependencies and licenses:
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
Open Source Software Licensed under the MIT License:
|
| 16 |
+
--------------------------------------------------------------------
|
| 17 |
+
1. syncformer
|
| 18 |
+
Copyright (c) 2024 Vladimir Iashin
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
Terms of the MIT License:
|
| 22 |
+
--------------------------------------------------------------------
|
| 23 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 24 |
+
|
| 25 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 26 |
+
|
| 27 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
HunyuanVideo-Foley/README.md
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley
|
| 4 |
+
|
| 5 |
+
<img src="assets/logo.png" alt="HunyuanVideo-Foley Logo" width="400">
|
| 6 |
+
|
| 7 |
+
<h4>Multimodal Diffusion with Representation Alignment for High-Fidelity Foley Audio Generation</h4>
|
| 8 |
+
|
| 9 |
+
<p align="center">
|
| 10 |
+
<strong>Professional-grade AI sound effect generation for video content creators</strong>
|
| 11 |
+
</p>
|
| 12 |
+
|
| 13 |
+
<div align="center">
|
| 14 |
+
<a href=https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley target="_blank"><img src=https://img.shields.io/badge/Code-black.svg?logo=github height=22px></a>
|
| 15 |
+
<a href=https://szczesnys.github.io/hunyuanvideo-foley target="_blank"><img src=https://img.shields.io/badge/Page-bb8a2e.svg?logo=github height=22px></a>
|
| 16 |
+
<a href=https://huggingface.co/tencent/HunyuanVideo-Foley target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20Models-d96902.svg height=22px></a>
|
| 17 |
+
<a href=https://huggingface.co/spaces/tencent/HunyuanVideo-Foley target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20Demo-276cb4.svg height=22px></a>
|
| 18 |
+
<a href=https://arxiv.org/abs/2508.16930 target="_blank"><img src=https://img.shields.io/badge/Report-b5212f.svg?logo=arxiv height=22px></a>
|
| 19 |
+
<a href=https://x.com/TencentHunyuan target="_blank"><img src=https://img.shields.io/badge/Hunyuan-black.svg?logo=x height=22px></a>
|
| 20 |
+
<a href=https://discord.gg/YEyGGn6Bte target="_blank"><img src=https://img.shields.io/badge/Hunyuan-141984.svg?logo=discord height=22px></a>
|
| 21 |
+
</div>
|
| 22 |
+
|
| 23 |
+
</div>
|
| 24 |
+
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
+
<div align="center">
|
| 28 |
+
|
| 29 |
+
### 👥 **Authors**
|
| 30 |
+
|
| 31 |
+
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 15px; margin: 20px 0;">
|
| 32 |
+
|
| 33 |
+
**Sizhe Shan**<sup>1,2*</sup> • **Qiulin Li**<sup>1,3*</sup> • **Yutao Cui**<sup>1</sup> • **Miles Yang**<sup>1</sup> • **Yuehai Wang**<sup>2</sup> • **Qun Yang**<sup>3</sup> • **Jin Zhou**<sup>1†</sup> • **Zhao Zhong**<sup>1</sup>
|
| 34 |
+
|
| 35 |
+
</div>
|
| 36 |
+
|
| 37 |
+
<div style="margin-top: 15px; font-size: 14px; color: #666;">
|
| 38 |
+
|
| 39 |
+
🏢 <sup>1</sup>**Tencent Hunyuan** • 🎓 <sup>2</sup>**Zhejiang University** • ✈️ <sup>3</sup>**Nanjing University of Aeronautics and Astronautics**
|
| 40 |
+
|
| 41 |
+
*Equal contribution • †Project lead
|
| 42 |
+
|
| 43 |
+
</div>
|
| 44 |
+
|
| 45 |
+
</div>
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## 🔥🔥🔥 **News**
|
| 51 |
+
|
| 52 |
+
<div style="background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%); padding: 20px; border-radius: 15px; margin: 20px 0; border-left: 5px solid #2196f3;">
|
| 53 |
+
|
| 54 |
+
- **[2025.9.29]** 🚀 **HunyuanVideo-Foley-XL Model Release** - Released the XL-sized model with offload inference support, significantly reducing VRAM requirements.
|
| 55 |
+
- **[2025.8.28]** 🌟 **HunyuanVideo-Foley Open Source Release** - Inference code and model weights publicly available.
|
| 56 |
+
|
| 57 |
+
</div>
|
| 58 |
+
|
| 59 |
+
---
|
| 60 |
+
|
| 61 |
+
## 🎥 **Demo & Showcase**
|
| 62 |
+
|
| 63 |
+
<div align="center">
|
| 64 |
+
|
| 65 |
+
> **Experience the magic of AI-generated Foley audio in perfect sync with video content!**
|
| 66 |
+
|
| 67 |
+
<div style="border: 3px solid #4A90E2; border-radius: 15px; padding: 10px; margin: 20px 0; background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);">
|
| 68 |
+
|
| 69 |
+
<video src="https://github.com/user-attachments/assets/d6e1b6fd-6980-4a68-8717-74298d064195" width="80%" controls style="border-radius: 10px; box-shadow: 0 8px 32px rgba(0,0,0,0.1);"> </video>
|
| 70 |
+
|
| 71 |
+
<p><em>🎬 Watch how HunyuanVideo-Foley generates immersive sound effects synchronized with video content</em></p>
|
| 72 |
+
|
| 73 |
+
</div>
|
| 74 |
+
|
| 75 |
+
---
|
| 76 |
+
|
| 77 |
+
## 🤝 **Community Contributions**
|
| 78 |
+
|
| 79 |
+
<div style="background: #f8f9fa; padding: 20px; border-radius: 10px; border-left: 4px solid #28a745; margin: 20px 0; color: #333;">
|
| 80 |
+
|
| 81 |
+
**ComfyUI Integration** - Thanks to the amazing community for creating ComfyUI nodes:
|
| 82 |
+
|
| 83 |
+
- **[if-ai/ComfyUI_HunyuanVideoFoley](https://github.com/if-ai/ComfyUI_HunyuanVideoFoley)** - ComfyUI workflow integration that supports CPU offloading and FP8 quantization
|
| 84 |
+
- **[phazei/ComfyUI-HunyuanVideo-Foley](https://github.com/phazei/ComfyUI-HunyuanVideo-Foley)** - Alternative ComfyUI node implementation that supports different precision modes
|
| 85 |
+
|
| 86 |
+
</div>
|
| 87 |
+
|
| 88 |
+
<div align="center" style="margin: 20px 0;">
|
| 89 |
+
|
| 90 |
+
**🌟 We encourage and appreciate community contributions that make HunyuanVideo-Foley more accessible!**
|
| 91 |
+
|
| 92 |
+
</div>
|
| 93 |
+
|
| 94 |
+
---
|
| 95 |
+
### ✨ **Key Highlights**
|
| 96 |
+
|
| 97 |
+
<table align="center" style="border: none; margin: 20px 0;">
|
| 98 |
+
<tr>
|
| 99 |
+
<td align="center" width="33%">
|
| 100 |
+
|
| 101 |
+
🎭 **Multi-scenario Sync**
|
| 102 |
+
High-quality audio synchronized with complex video scenes
|
| 103 |
+
|
| 104 |
+
</td>
|
| 105 |
+
<td align="center" width="33%">
|
| 106 |
+
|
| 107 |
+
🧠 **Multi-modal Balance**
|
| 108 |
+
Perfect harmony between visual and textual information
|
| 109 |
+
|
| 110 |
+
</td>
|
| 111 |
+
<td align="center" width="33%">
|
| 112 |
+
|
| 113 |
+
🎵 **48kHz Hi-Fi Output**
|
| 114 |
+
Professional-grade audio generation with crystal clarity
|
| 115 |
+
|
| 116 |
+
</td>
|
| 117 |
+
</tr>
|
| 118 |
+
</table>
|
| 119 |
+
|
| 120 |
+
</div>
|
| 121 |
+
|
| 122 |
+
---
|
| 123 |
+
|
| 124 |
+
## 📄 **Abstract**
|
| 125 |
+
|
| 126 |
+
<div align="center" style="background: linear-gradient(135deg, #ffeef8 0%, #f0f8ff 100%); padding: 30px; border-radius: 20px; margin: 20px 0; border-left: 5px solid #ff6b9d; color: #333;">
|
| 127 |
+
|
| 128 |
+
**🚀 Tencent Hunyuan** open-sources **HunyuanVideo-Foley**, an end-to-end video sound effect generation model!
|
| 129 |
+
|
| 130 |
+
*A professional-grade AI tool specifically designed for video content creators, widely applicable to diverse scenarios including short video creation, film production, advertising creativity, and game development.*
|
| 131 |
+
|
| 132 |
+
</div>
|
| 133 |
+
|
| 134 |
+
### 🎯 **Core Highlights**
|
| 135 |
+
|
| 136 |
+
<div style="display: grid; grid-template-columns: 1fr; gap: 15px; margin: 20px 0;">
|
| 137 |
+
|
| 138 |
+
<div style="border-left: 4px solid #4CAF50; padding: 15px; background: #f8f9fa; border-radius: 8px; color: #333;">
|
| 139 |
+
|
| 140 |
+
**🎬 Multi-scenario Audio-Visual Synchronization**
|
| 141 |
+
Supports generating high-quality audio that is synchronized and semantically aligned with complex video scenes, enhancing realism and immersive experience for film/TV and gaming applications.
|
| 142 |
+
|
| 143 |
+
</div>
|
| 144 |
+
|
| 145 |
+
<div style="border-left: 4px solid #2196F3; padding: 15px; background: #f8f9fa; border-radius: 8px; color: #333;">
|
| 146 |
+
|
| 147 |
+
**⚖️ Multi-modal Semantic Balance**
|
| 148 |
+
Intelligently balances visual and textual information analysis, comprehensively orchestrates sound effect elements, avoids one-sided generation, and meets personalized dubbing requirements.
|
| 149 |
+
|
| 150 |
+
</div>
|
| 151 |
+
|
| 152 |
+
<div style="border-left: 4px solid #FF9800; padding: 15px; background: #f8f9fa; border-radius: 8px; color: #333;">
|
| 153 |
+
|
| 154 |
+
**🎵 High-fidelity Audio Output**
|
| 155 |
+
Self-developed 48kHz audio VAE perfectly reconstructs sound effects, music, and vocals, achieving professional-grade audio generation quality.
|
| 156 |
+
|
| 157 |
+
</div>
|
| 158 |
+
|
| 159 |
+
</div>
|
| 160 |
+
|
| 161 |
+
<div align="center" style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px; border-radius: 15px; margin: 20px 0; color: #333;">
|
| 162 |
+
|
| 163 |
+
**🏆 SOTA Performance Achieved**
|
| 164 |
+
|
| 165 |
+
*HunyuanVideo-Foley comprehensively leads the field across multiple evaluation benchmarks, achieving new state-of-the-art levels in audio fidelity, visual-semantic alignment, temporal alignment, and distribution matching - surpassing all open-source solutions!*
|
| 166 |
+
|
| 167 |
+
</div>
|
| 168 |
+
|
| 169 |
+
<div align="center">
|
| 170 |
+
|
| 171 |
+

|
| 172 |
+
*📊 Performance comparison across different evaluation metrics - HunyuanVideo-Foley leads in all categories*
|
| 173 |
+
|
| 174 |
+
</div>
|
| 175 |
+
|
| 176 |
+
---
|
| 177 |
+
|
| 178 |
+
## 🔧 **Technical Architecture**
|
| 179 |
+
|
| 180 |
+
### 📊 **Data Pipeline Design**
|
| 181 |
+
|
| 182 |
+
<div align="center" style="margin: 20px 0; color: #333;">
|
| 183 |
+
|
| 184 |
+

|
| 185 |
+
*🔄 Comprehensive data processing pipeline for high-quality text-video-audio datasets*
|
| 186 |
+
|
| 187 |
+
</div>
|
| 188 |
+
|
| 189 |
+
<div style="background: #f8f9fa; padding: 20px; border-radius: 10px; border-left: 4px solid #17a2b8; margin: 20px 0;">
|
| 190 |
+
|
| 191 |
+
The **TV2A (Text-Video-to-Audio)** task presents a complex multimodal generation challenge requiring large-scale, high-quality datasets. Our comprehensive data pipeline systematically identifies and excludes unsuitable content to produce robust and generalizable audio generation capabilities.
|
| 192 |
+
|
| 193 |
+
</div>
|
| 194 |
+
|
| 195 |
+
### 🏗️ **Model Architecture**
|
| 196 |
+
|
| 197 |
+
<div align="center" style="margin: 20px 0; color: #333;">
|
| 198 |
+
|
| 199 |
+

|
| 200 |
+
*🧠 HunyuanVideo-Foley hybrid architecture with multimodal and unimodal transformer blocks*
|
| 201 |
+
|
| 202 |
+
</div>
|
| 203 |
+
|
| 204 |
+
<div style="background: #f8f9fa; padding: 20px; border-radius: 10px; border-left: 4px solid #28a745; margin: 20px 0;">
|
| 205 |
+
|
| 206 |
+
**HunyuanVideo-Foley** employs a sophisticated hybrid architecture:
|
| 207 |
+
|
| 208 |
+
- **🔄 Multimodal Transformer Blocks**: Process visual-audio streams simultaneously
|
| 209 |
+
- **🎵 Unimodal Transformer Blocks**: Focus on audio stream refinement
|
| 210 |
+
- **👁️ Visual Encoding**: Pre-trained encoder extracts visual features from video frames
|
| 211 |
+
- **📝 Text Processing**: Semantic features extracted via pre-trained text encoder
|
| 212 |
+
- **🎧 Audio Encoding**: Latent representations with Gaussian noise perturbation
|
| 213 |
+
- **⏰ Temporal Alignment**: Synchformer-based frame-level synchronization with gated modulation
|
| 214 |
+
|
| 215 |
+
</div>
|
| 216 |
+
|
| 217 |
+
---
|
| 218 |
+
|
| 219 |
+
## 📈 **Performance Benchmarks**
|
| 220 |
+
|
| 221 |
+
### 🎬 **MovieGen-Audio-Bench Results**
|
| 222 |
+
|
| 223 |
+
<div align="center">
|
| 224 |
+
|
| 225 |
+
> *Objective and Subjective evaluation results demonstrating superior performance across all metrics*
|
| 226 |
+
|
| 227 |
+
</div>
|
| 228 |
+
|
| 229 |
+
<div style="overflow-x: auto; margin: 20px 0;">
|
| 230 |
+
|
| 231 |
+
| 🏆 **Method** | **PQ** ↑ | **PC** ↓ | **CE** ↑ | **CU** ↑ | **IB** ↑ | **DeSync** ↓ | **CLAP** ↑ | **MOS-Q** ↑ | **MOS-S** ↑ | **MOS-T** ↑ |
|
| 232 |
+
|:-------------:|:--------:|:--------:|:--------:|:--------:|:--------:|:-------------:|:-----------:|:------------:|:------------:|:------------:|
|
| 233 |
+
| FoleyGrafter | 6.27 | 2.72 | 3.34 | 5.68 | 0.17 | 1.29 | 0.14 | 3.36±0.78 | 3.54±0.88 | 3.46±0.95 |
|
| 234 |
+
| V-AURA | 5.82 | 4.30 | 3.63 | 5.11 | 0.23 | 1.38 | 0.14 | 2.55±0.97 | 2.60±1.20 | 2.70±1.37 |
|
| 235 |
+
| Frieren | 5.71 | 2.81 | 3.47 | 5.31 | 0.18 | 1.39 | 0.16 | 2.92±0.95 | 2.76±1.20 | 2.94±1.26 |
|
| 236 |
+
| MMAudio | 6.17 | 2.84 | 3.59 | 5.62 | 0.27 | 0.80 | 0.35 | 3.58±0.84 | 3.63±1.00 | 3.47±1.03 |
|
| 237 |
+
| ThinkSound | 6.04 | 3.73 | 3.81 | 5.59 | 0.18 | 0.91 | 0.20 | 3.20±0.97 | 3.01±1.04 | 3.02±1.08 |
|
| 238 |
+
| **HunyuanVideo-Foley (ours)** | **6.59** | **2.74** | **3.88** | **6.13** | **0.35** | **0.74** | **0.33** | **4.14±0.68** | **4.12±0.77** | **4.15±0.75** |
|
| 239 |
+
|
| 240 |
+
</div>
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
### 🎯 **Kling-Audio-Eval Results**
|
| 244 |
+
|
| 245 |
+
<div align="center">
|
| 246 |
+
|
| 247 |
+
> *Comprehensive objective evaluation showcasing state-of-the-art performance*
|
| 248 |
+
|
| 249 |
+
</div>
|
| 250 |
+
|
| 251 |
+
<div style="overflow-x: auto; margin: 20px 0;">
|
| 252 |
+
|
| 253 |
+
| 🏆 **Method** | **FD_PANNs** ↓ | **FD_PASST** ↓ | **KL** ↓ | **IS** ↑ | **PQ** ↑ | **PC** ↓ | **CE** ↑ | **CU** ↑ | **IB** ↑ | **DeSync** ↓ | **CLAP** ↑ |
|
| 254 |
+
|:-------------:|:--------------:|:--------------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:-------------:|:-----------:|
|
| 255 |
+
| FoleyGrafter | 22.30 | 322.63 | 2.47 | 7.08 | 6.05 | 2.91 | 3.28 | 5.44 | 0.22 | 1.23 | 0.22 |
|
| 256 |
+
| V-AURA | 33.15 | 474.56 | 3.24 | 5.80 | 5.69 | 3.98 | 3.13 | 4.83 | 0.25 | 0.86 | 0.13 |
|
| 257 |
+
| Frieren | 16.86 | 293.57 | 2.95 | 7.32 | 5.72 | 2.55 | 2.88 | 5.10 | 0.21 | 0.86 | 0.16 |
|
| 258 |
+
| MMAudio | 9.01 | 205.85 | 2.17 | 9.59 | 5.94 | 2.91 | 3.30 | 5.39 | 0.30 | 0.56 | 0.27 |
|
| 259 |
+
| ThinkSound | 9.92 | 228.68 | 2.39 | 6.86 | 5.78 | 3.23 | 3.12 | 5.11 | 0.22 | 0.67 | 0.22 |
|
| 260 |
+
| **HunyuanVideo-Foley (ours)** | **6.07** | **202.12** | **1.89** | **8.30** | **6.12** | **2.76** | **3.22** | **5.53** | **0.38** | **0.54** | **0.24** |
|
| 261 |
+
|
| 262 |
+
</div>
|
| 263 |
+
|
| 264 |
+
<div align="center" style="background: linear-gradient(135deg, #4CAF50 0%, #45a049 100%); color: white; padding: 15px; border-radius: 10px; margin: 20px 0; color: #333;">
|
| 265 |
+
|
| 266 |
+
**🎉 Outstanding Results!** HunyuanVideo-Foley achieves the best scores across **ALL** evaluation metrics, demonstrating significant improvements in audio quality, synchronization, and semantic alignment.
|
| 267 |
+
|
| 268 |
+
</div>
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
---
|
| 273 |
+
|
| 274 |
+
## 🚀 **Quick Start**
|
| 275 |
+
|
| 276 |
+
### 📦 **Installation**
|
| 277 |
+
|
| 278 |
+
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px; border-radius: 15px; margin: 20px 0; color: #333;">
|
| 279 |
+
|
| 280 |
+
**🔧 System Requirements**
|
| 281 |
+
- **CUDA**: 12.4 or 11.8 recommended
|
| 282 |
+
- **Python**: 3.8+
|
| 283 |
+
- **OS**: Linux (primary support)
|
| 284 |
+
- **VRAM**: 20GB for XXL model (or 12GB with `--enable_offload`), 16GB for XL model (or 8GB with `--enable_offload`)
|
| 285 |
+
|
| 286 |
+
</div>
|
| 287 |
+
|
| 288 |
+
#### **Step 1: Clone Repository**
|
| 289 |
+
|
| 290 |
+
```bash
|
| 291 |
+
# 📥 Clone the repository
|
| 292 |
+
git clone https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley
|
| 293 |
+
cd HunyuanVideo-Foley
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
#### **Step 2: Environment Setup**
|
| 297 |
+
|
| 298 |
+
<div style="background: #fff3cd; padding: 15px; border-radius: 8px; border-left: 4px solid #ffc107; margin: 10px 0; color: #333;">
|
| 299 |
+
|
| 300 |
+
💡 **Tip**: We recommend using [Conda](https://docs.anaconda.com/free/miniconda/index.html) for Python environment management.
|
| 301 |
+
|
| 302 |
+
</div>
|
| 303 |
+
|
| 304 |
+
```bash
|
| 305 |
+
# 🔧 Install dependencies
|
| 306 |
+
pip install -r requirements.txt
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
#### **Step 3: Download Pretrained Models**
|
| 310 |
+
|
| 311 |
+
<div style="background: #d1ecf1; padding: 15px; border-radius: 8px; border-left: 4px solid #17a2b8; margin: 10px 0;color: #333;">
|
| 312 |
+
|
| 313 |
+
🔗 **Download Model weights from Huggingface**
|
| 314 |
+
```bash
|
| 315 |
+
# using git-lfs
|
| 316 |
+
git clone https://huggingface.co/tencent/HunyuanVideo-Foley
|
| 317 |
+
|
| 318 |
+
# using huggingface-cli
|
| 319 |
+
huggingface-cli download tencent/HunyuanVideo-Foley
|
| 320 |
+
```
|
| 321 |
+
|
| 322 |
+
<!-- 🔗 **Download Model weights from ModelScope** -->
|
| 323 |
+
<!-- ```bash -->
|
| 324 |
+
<!-- # using git-lfs -->
|
| 325 |
+
<!-- git clone https://huggingface.co/tencent/HunyuanVideo-Foley -->
|
| 326 |
+
<!-- -->
|
| 327 |
+
<!-- # using huggingface-cli -->
|
| 328 |
+
<!-- huggingface-cli download tencent/HunyuanVideo-Foley -->
|
| 329 |
+
<!-- ``` -->
|
| 330 |
+
|
| 331 |
+
</div>
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
---
|
| 335 |
+
|
| 336 |
+
## 💻 **Usage**
|
| 337 |
+
|
| 338 |
+
### 📊 **Model Specifications**
|
| 339 |
+
|
| 340 |
+
| Model | Checkpoint | VRAM (Normal) | VRAM (Offload) |
|
| 341 |
+
|-------|------------|---------------|----------------|
|
| 342 |
+
| **XXL** *(Default)* | `hunyuanvideo_foley.pth` | 20GB | 12GB |
|
| 343 |
+
| **XL** | `hunyuanvideo_foley_xl.pth` | 16GB | 8GB |
|
| 344 |
+
|
| 345 |
+
### 🎬 **Single Video Generation**
|
| 346 |
+
|
| 347 |
+
<div style="background: #e8f5e8; padding: 15px; border-radius: 8px; border-left: 4px solid #28a745; margin: 10px 0;color: #333;">
|
| 348 |
+
|
| 349 |
+
Generate Foley audio for a single video file with text description:
|
| 350 |
+
|
| 351 |
+
</div>
|
| 352 |
+
|
| 353 |
+
```bash
|
| 354 |
+
# Use XXL model (default, best quality)
|
| 355 |
+
python3 infer.py \
|
| 356 |
+
--model_path PRETRAINED_MODEL_PATH_DIR \
|
| 357 |
+
--single_video video_path \
|
| 358 |
+
--single_prompt "audio description" \
|
| 359 |
+
--output_dir OUTPUT_DIR \
|
| 360 |
+
# --enable_offload
|
| 361 |
+
|
| 362 |
+
# Use XL model (memory-friendly)
|
| 363 |
+
python3 infer.py \
|
| 364 |
+
--model_path PRETRAINED_MODEL_PATH_DIR \
|
| 365 |
+
--model_size xl \
|
| 366 |
+
--single_video video_path \
|
| 367 |
+
--single_prompt "audio description" \
|
| 368 |
+
--output_dir OUTPUT_DIR \
|
| 369 |
+
# --enable_offload
|
| 370 |
+
```
|
| 371 |
+
|
| 372 |
+
### 📂 **Batch Processing**
|
| 373 |
+
|
| 374 |
+
<div style="background: #fff3e0; padding: 15px; border-radius: 8px; border-left: 4px solid #ff9800; margin: 10px 0;color: #333;">
|
| 375 |
+
|
| 376 |
+
Process multiple videos using a CSV file with video paths and descriptions:
|
| 377 |
+
|
| 378 |
+
</div>
|
| 379 |
+
|
| 380 |
+
```bash
|
| 381 |
+
# Download sample test videos
|
| 382 |
+
bash ./download_test_videos.sh
|
| 383 |
+
|
| 384 |
+
# Batch processing
|
| 385 |
+
python3 infer.py \
|
| 386 |
+
--model_path PRETRAINED_MODEL_PATH_DIR \
|
| 387 |
+
--csv_path assets/test.csv \
|
| 388 |
+
--output_dir OUTPUT_DIR \
|
| 389 |
+
# --enable_offload
|
| 390 |
+
```
|
| 391 |
+
|
| 392 |
+
### 🌐 **Interactive Web Interface**
|
| 393 |
+
|
| 394 |
+
<div style="background: #f3e5f5; padding: 15px; border-radius: 8px; border-left: 4px solid #9c27b0; margin: 10px 0;color: #333;">
|
| 395 |
+
|
| 396 |
+
Launch a user-friendly Gradio web interface for easy interaction:
|
| 397 |
+
|
| 398 |
+
</div>
|
| 399 |
+
|
| 400 |
+
```bash
|
| 401 |
+
# Launch with XXL model (default)
|
| 402 |
+
export HIFI_FOLEY_MODEL_PATH=PRETRAINED_MODEL_PATH_DIR
|
| 403 |
+
python3 gradio_app.py
|
| 404 |
+
|
| 405 |
+
# Launch with XL model (memory-friendly)
|
| 406 |
+
export HIFI_FOLEY_MODEL_PATH=PRETRAINED_MODEL_PATH_DIR
|
| 407 |
+
MODEL_SIZE=xl python3 gradio_app.py
|
| 408 |
+
|
| 409 |
+
# Optional: Enable offload to reduce memory usage
|
| 410 |
+
ENABLE_OFFLOAD=true python3 gradio_app.py
|
| 411 |
+
```
|
| 412 |
+
|
| 413 |
+
<div align="center" style="margin: 20px 0; color: #333;">
|
| 414 |
+
|
| 415 |
+
*🚀 Then open your browser and navigate to the provided local URL to start generating Foley audio!*
|
| 416 |
+
|
| 417 |
+
</div>
|
| 418 |
+
|
| 419 |
+
---
|
| 420 |
+
|
| 421 |
+
## 📚 **Citation**
|
| 422 |
+
|
| 423 |
+
<div style="background: #f8f9fa; padding: 20px; border-radius: 10px; border-left: 4px solid #6c757d; margin: 20px 0; color: #333;">
|
| 424 |
+
|
| 425 |
+
If you find **HunyuanVideo-Foley** useful for your research, please consider citing our paper:
|
| 426 |
+
|
| 427 |
+
</div>
|
| 428 |
+
|
| 429 |
+
```bibtex
|
| 430 |
+
@misc{shan2025hunyuanvideofoleymultimodaldiffusionrepresentation,
|
| 431 |
+
title={HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment for High-Fidelity Foley Audio Generation},
|
| 432 |
+
author={Sizhe Shan and Qiulin Li and Yutao Cui and Miles Yang and Yuehai Wang and Qun Yang and Jin Zhou and Zhao Zhong},
|
| 433 |
+
year={2025},
|
| 434 |
+
eprint={2508.16930},
|
| 435 |
+
archivePrefix={arXiv},
|
| 436 |
+
primaryClass={eess.AS},
|
| 437 |
+
url={https://arxiv.org/abs/2508.16930},
|
| 438 |
+
}
|
| 439 |
+
```
|
| 440 |
+
## Star History
|
| 441 |
+
|
| 442 |
+
[](https://www.star-history.com/#Tencent-Hunyuan/HunyuanVideo-Foley&Date)
|
| 443 |
+
---
|
| 444 |
+
|
| 445 |
+
## 🙏 **Acknowledgements**
|
| 446 |
+
|
| 447 |
+
<div align="center">
|
| 448 |
+
|
| 449 |
+
**We extend our heartfelt gratitude to the open-source community!**
|
| 450 |
+
|
| 451 |
+
</div>
|
| 452 |
+
|
| 453 |
+
<table align="center" style="width: 100%; border: none; margin: 20px 0;">
|
| 454 |
+
<tr>
|
| 455 |
+
<td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
|
| 456 |
+
|
| 457 |
+
🎨 **[Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)**
|
| 458 |
+
*Foundation diffusion models*
|
| 459 |
+
|
| 460 |
+
</td>
|
| 461 |
+
<td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
|
| 462 |
+
|
| 463 |
+
⚡ **[FLUX](https://github.com/black-forest-labs/flux)**
|
| 464 |
+
*Advanced generation techniques*
|
| 465 |
+
|
| 466 |
+
</td>
|
| 467 |
+
<td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
|
| 468 |
+
|
| 469 |
+
🎵 **[MMAudio](https://github.com/hkchengrex/MMAudio)**
|
| 470 |
+
*Multimodal audio generation*
|
| 471 |
+
|
| 472 |
+
</td>
|
| 473 |
+
</tr>
|
| 474 |
+
<tr>
|
| 475 |
+
<td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
|
| 476 |
+
|
| 477 |
+
🤗 **[HuggingFace](https://huggingface.co)**
|
| 478 |
+
*Platform & diffusers library*
|
| 479 |
+
|
| 480 |
+
</td>
|
| 481 |
+
<td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
|
| 482 |
+
|
| 483 |
+
🗜️ **[DAC](https://github.com/descriptinc/descript-audio-codec)**
|
| 484 |
+
*High-Fidelity Audio Compression*
|
| 485 |
+
|
| 486 |
+
</td>
|
| 487 |
+
<td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
|
| 488 |
+
|
| 489 |
+
🔗 **[Synchformer](https://github.com/v-iashin/Synchformer)**
|
| 490 |
+
*Audio-Visual Synchronization*
|
| 491 |
+
|
| 492 |
+
</td>
|
| 493 |
+
</tr>
|
| 494 |
+
</table>
|
| 495 |
+
|
| 496 |
+
<div align="center" style="background: linear-gradient(135deg, #74b9ff 0%, #0984e3 100%); color: white; padding: 20px; border-radius: 15px; margin: 20px 0;, color: #333;">
|
| 497 |
+
|
| 498 |
+
**🌟 Special thanks to all researchers and developers who contribute to the advancement of AI-generated audio and multimodal learning!**
|
| 499 |
+
|
| 500 |
+
</div>
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
---
|
| 504 |
+
|
| 505 |
+
<div align="center" style="margin: 30px 0;">
|
| 506 |
+
|
| 507 |
+
### 🔗 **Connect with Us**
|
| 508 |
+
|
| 509 |
+
[](https://github.com/Tencent-Hunyuan)
|
| 510 |
+
[](https://twitter.com/Tencent)
|
| 511 |
+
[](https://hunyuan.tencent.com/)
|
| 512 |
+
|
| 513 |
+
<p style="color: #666; margin-top: 15px; font-size: 14px;">
|
| 514 |
+
|
| 515 |
+
© 2025 Tencent Hunyuan. All rights reserved. | Made with ❤️ for the AI community
|
| 516 |
+
|
| 517 |
+
</p>
|
| 518 |
+
|
| 519 |
+
</div>
|
HunyuanVideo-Foley/build_package.sh
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# 构建 HunyuanVideo-Foley Python 包的脚本
|
| 3 |
+
|
| 4 |
+
set -e # 出现错误时退出
|
| 5 |
+
|
| 6 |
+
echo "🚀 开始构建 HunyuanVideo-Foley Python 包..."
|
| 7 |
+
|
| 8 |
+
# 清理之前的构建文件
|
| 9 |
+
echo "🧹 清理之前的构建文件..."
|
| 10 |
+
rm -rf build/ dist/ *.egg-info/
|
| 11 |
+
|
| 12 |
+
# 检查必要的工具
|
| 13 |
+
echo "🔍 检查构建工具..."
|
| 14 |
+
python -c "import setuptools, wheel; print('✅ setuptools和wheel已安装')" || {
|
| 15 |
+
echo "❌ 请安装构建工具: pip install setuptools wheel"
|
| 16 |
+
exit 1
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
# 检查setup.py
|
| 20 |
+
echo "🔍 验证setup.py配置..."
|
| 21 |
+
python setup.py check --restructuredtext --strict || {
|
| 22 |
+
echo "⚠️ setup.py验证有警告,但继续构建..."
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
# 构建源码分发包
|
| 26 |
+
echo "📦 构建源码分发包..."
|
| 27 |
+
python setup.py sdist
|
| 28 |
+
|
| 29 |
+
# 构建wheel包
|
| 30 |
+
echo "🎡 构建wheel包..."
|
| 31 |
+
python setup.py bdist_wheel
|
| 32 |
+
|
| 33 |
+
# 显示构建结果
|
| 34 |
+
echo "✅ 构建完成!生成的包:"
|
| 35 |
+
ls -la dist/
|
| 36 |
+
|
| 37 |
+
# 验证包
|
| 38 |
+
echo "🔍 验证生成的包..."
|
| 39 |
+
python -m pip check dist/*.whl || echo "⚠️ 包验证有警告"
|
| 40 |
+
|
| 41 |
+
echo ""
|
| 42 |
+
echo "📝 安装说明:"
|
| 43 |
+
echo "# 从wheel文件安装:"
|
| 44 |
+
echo "pip install dist/hunyuanvideo_foley-1.0.0-py3-none-any.whl"
|
| 45 |
+
echo ""
|
| 46 |
+
echo "# 开发模式安装:"
|
| 47 |
+
echo "pip install -e ."
|
| 48 |
+
echo ""
|
| 49 |
+
echo "# 安装所有可选依赖:"
|
| 50 |
+
echo "pip install -e .[all]"
|
| 51 |
+
echo ""
|
| 52 |
+
|
| 53 |
+
echo "⚠️ 注意:某些依赖需要单独安装:"
|
| 54 |
+
echo "pip install git+https://github.com/descriptinc/audiotools"
|
| 55 |
+
echo "pip install git+https://github.com/huggingface/transformers@v4.49.0-SigLIP-2"
|
| 56 |
+
|
| 57 |
+
echo ""
|
| 58 |
+
echo "🎉 构建完成!查看 INSTALL.md 获取详细安装指南。"
|
HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_config:
|
| 2 |
+
model_name: HunyuanVideo-Foley-XL
|
| 3 |
+
model_type: 1d
|
| 4 |
+
model_precision: bf16
|
| 5 |
+
model_kwargs:
|
| 6 |
+
depth_triple_blocks: 12
|
| 7 |
+
depth_single_blocks: 24
|
| 8 |
+
hidden_size: 1408
|
| 9 |
+
num_heads: 11
|
| 10 |
+
mlp_ratio: 4
|
| 11 |
+
mlp_act_type: "gelu_tanh"
|
| 12 |
+
qkv_bias: True
|
| 13 |
+
qk_norm: True
|
| 14 |
+
qk_norm_type: "rms"
|
| 15 |
+
attn_mode: "torch"
|
| 16 |
+
embedder_type: "default"
|
| 17 |
+
interleaved_audio_visual_rope: True
|
| 18 |
+
enable_learnable_empty_visual_feat: True
|
| 19 |
+
sync_modulation: False
|
| 20 |
+
add_sync_feat_to_audio: True
|
| 21 |
+
cross_attention: True
|
| 22 |
+
use_attention_mask: False
|
| 23 |
+
condition_projection: "linear"
|
| 24 |
+
sync_feat_dim: 768 # syncformer 768 dim
|
| 25 |
+
condition_dim: 768 # clap 768 text condition dim (clip-text)
|
| 26 |
+
clip_dim: 768 # siglip2 visual dim
|
| 27 |
+
audio_vae_latent_dim: 128
|
| 28 |
+
audio_frame_rate: 50
|
| 29 |
+
patch_size: 1
|
| 30 |
+
rope_dim_list: null
|
| 31 |
+
rope_theta: 10000
|
| 32 |
+
text_length: 77
|
| 33 |
+
clip_length: 64
|
| 34 |
+
sync_length: 192
|
| 35 |
+
depth_triple_ssl_encoder: null
|
| 36 |
+
depth_single_ssl_encoder: 8
|
| 37 |
+
use_repa_with_audiossl: True
|
| 38 |
+
|
| 39 |
+
diffusion_config:
|
| 40 |
+
denoise_type: "flow"
|
| 41 |
+
flow_path_type: "linear"
|
| 42 |
+
flow_predict_type: "velocity"
|
| 43 |
+
flow_reverse: True
|
| 44 |
+
flow_solver: "euler"
|
| 45 |
+
sample_flow_shift: 1.0
|
| 46 |
+
sample_use_flux_shift: False
|
| 47 |
+
flux_base_shift: 0.5
|
| 48 |
+
flux_max_shift: 1.15
|
HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_config:
|
| 2 |
+
model_name: HunyuanVideo-Foley-XXL
|
| 3 |
+
model_type: 1d
|
| 4 |
+
model_precision: bf16
|
| 5 |
+
model_kwargs:
|
| 6 |
+
depth_triple_blocks: 18
|
| 7 |
+
depth_single_blocks: 36
|
| 8 |
+
hidden_size: 1536
|
| 9 |
+
num_heads: 12
|
| 10 |
+
mlp_ratio: 4
|
| 11 |
+
mlp_act_type: "gelu_tanh"
|
| 12 |
+
qkv_bias: True
|
| 13 |
+
qk_norm: True
|
| 14 |
+
qk_norm_type: "rms"
|
| 15 |
+
attn_mode: "torch"
|
| 16 |
+
embedder_type: "default"
|
| 17 |
+
interleaved_audio_visual_rope: True
|
| 18 |
+
enable_learnable_empty_visual_feat: True
|
| 19 |
+
sync_modulation: False
|
| 20 |
+
add_sync_feat_to_audio: True
|
| 21 |
+
cross_attention: True
|
| 22 |
+
use_attention_mask: False
|
| 23 |
+
condition_projection: "linear"
|
| 24 |
+
sync_feat_dim: 768 # syncformer 768 dim
|
| 25 |
+
condition_dim: 768 # clap 768 text condition dim (clip-text)
|
| 26 |
+
clip_dim: 768 # siglip2 visual dim
|
| 27 |
+
audio_vae_latent_dim: 128
|
| 28 |
+
audio_frame_rate: 50
|
| 29 |
+
patch_size: 1
|
| 30 |
+
rope_dim_list: null
|
| 31 |
+
rope_theta: 10000
|
| 32 |
+
text_length: 77
|
| 33 |
+
clip_length: 64
|
| 34 |
+
sync_length: 192
|
| 35 |
+
depth_triple_ssl_encoder: null
|
| 36 |
+
depth_single_ssl_encoder: 8
|
| 37 |
+
use_repa_with_audiossl: True
|
| 38 |
+
|
| 39 |
+
diffusion_config:
|
| 40 |
+
denoise_type: "flow"
|
| 41 |
+
flow_path_type: "linear"
|
| 42 |
+
flow_predict_type: "velocity"
|
| 43 |
+
flow_reverse: True
|
| 44 |
+
flow_solver: "euler"
|
| 45 |
+
sample_flow_shift: 1.0
|
| 46 |
+
sample_use_flux_shift: False
|
| 47 |
+
flux_base_shift: 0.5
|
| 48 |
+
flux_max_shift: 1.15
|
HunyuanVideo-Foley/download_test_videos.sh
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Download MoviegenAudioBenchSfx 10 videos
|
| 4 |
+
curl -O https://texttoaudio-train-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuanvideo-foley_demo/MovieGenAudioBenchSfx.tar.gz
|
| 5 |
+
tar -xzvf MovieGenAudioBenchSfx.tar.gz -C ./assets
|
| 6 |
+
rm MovieGenAudioBenchSfx.tar.gz
|
| 7 |
+
|
| 8 |
+
# Download gradio example video
|
| 9 |
+
curl -O https://texttoaudio-train-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuanvideo-foley_demo/examples.tar.gz
|
| 10 |
+
tar -xvzf examples.tar.gz
|
| 11 |
+
rm examples.tar.gz
|
HunyuanVideo-Foley/gradio_app.py
ADDED
|
@@ -0,0 +1,834 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import torch
|
| 5 |
+
import torchaudio
|
| 6 |
+
from loguru import logger
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
import random
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from hunyuanvideo_foley.utils.model_utils import load_model
|
| 12 |
+
from hunyuanvideo_foley.utils.feature_utils import feature_process
|
| 13 |
+
from hunyuanvideo_foley.utils.model_utils import denoise_process
|
| 14 |
+
from hunyuanvideo_foley.utils.media_utils import merge_audio_video
|
| 15 |
+
|
| 16 |
+
# Global variables for model storage
|
| 17 |
+
model_dict = None
|
| 18 |
+
cfg = None
|
| 19 |
+
device = None
|
| 20 |
+
|
| 21 |
+
# need to modify the model path
|
| 22 |
+
MODEL_PATH = os.environ.get("HIFI_FOLEY_MODEL_PATH", "./pretrained_models/")
|
| 23 |
+
ENABLE_OFFLOAD = os.environ.get("ENABLE_OFFLOAD", "false").lower() in ("true", "1", "yes")
|
| 24 |
+
MODEL_SIZE = os.environ.get("MODEL_SIZE", "xxl") # default to xxl model
|
| 25 |
+
CONFIG_PATH = os.environ.get("CONFIG_PATH", "")
|
| 26 |
+
|
| 27 |
+
def setup_device(device_str: str = "auto", gpu_id: int = 0) -> torch.device:
|
| 28 |
+
"""Setup computing device"""
|
| 29 |
+
if device_str == "auto":
|
| 30 |
+
if torch.cuda.is_available():
|
| 31 |
+
device = torch.device(f"cuda:{gpu_id}")
|
| 32 |
+
logger.info(f"Using CUDA device: {device}")
|
| 33 |
+
elif torch.backends.mps.is_available():
|
| 34 |
+
device = torch.device("mps")
|
| 35 |
+
logger.info("Using MPS device")
|
| 36 |
+
else:
|
| 37 |
+
device = torch.device("cpu")
|
| 38 |
+
logger.info("Using CPU device")
|
| 39 |
+
else:
|
| 40 |
+
if device_str == "cuda":
|
| 41 |
+
device = torch.device(f"cuda:{gpu_id}")
|
| 42 |
+
else:
|
| 43 |
+
device = torch.device(device_str)
|
| 44 |
+
logger.info(f"Using specified device: {device}")
|
| 45 |
+
|
| 46 |
+
return device
|
| 47 |
+
|
| 48 |
+
def auto_load_models() -> str:
|
| 49 |
+
"""Automatically load preset models"""
|
| 50 |
+
global model_dict, cfg, device
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
if not os.path.exists(MODEL_PATH):
|
| 54 |
+
return f"❌ Model directory not found: {MODEL_PATH}"
|
| 55 |
+
|
| 56 |
+
# Use GPU by default
|
| 57 |
+
device = setup_device("auto", 0)
|
| 58 |
+
|
| 59 |
+
# Auto-select config if not specified
|
| 60 |
+
config_path = CONFIG_PATH
|
| 61 |
+
if not config_path:
|
| 62 |
+
config_mapping = {
|
| 63 |
+
"xl": "configs/hunyuanvideo-foley-xl.yaml",
|
| 64 |
+
"xxl": "configs/hunyuanvideo-foley-xxl.yaml"
|
| 65 |
+
}
|
| 66 |
+
config_path = config_mapping.get(MODEL_SIZE, "configs/hunyuanvideo-foley-xxl.yaml")
|
| 67 |
+
|
| 68 |
+
# Load model
|
| 69 |
+
logger.info("Auto-loading model...")
|
| 70 |
+
logger.info(f"Model path: {MODEL_PATH}")
|
| 71 |
+
logger.info(f"Model size: {MODEL_SIZE}")
|
| 72 |
+
logger.info(f"Config path: {config_path}")
|
| 73 |
+
logger.info(f"Offload mode: {'enabled' if ENABLE_OFFLOAD else 'disabled'}")
|
| 74 |
+
|
| 75 |
+
model_dict, cfg = load_model(MODEL_PATH, config_path, device, enable_offload=ENABLE_OFFLOAD, model_size=MODEL_SIZE)
|
| 76 |
+
|
| 77 |
+
logger.info("✅ Model loaded successfully!")
|
| 78 |
+
return "✅ Model loaded successfully!"
|
| 79 |
+
|
| 80 |
+
except Exception as e:
|
| 81 |
+
logger.error(f"Model loading failed: {str(e)}")
|
| 82 |
+
return f"❌ Model loading failed: {str(e)}"
|
| 83 |
+
|
| 84 |
+
def infer_single_video(
|
| 85 |
+
video_file,
|
| 86 |
+
text_prompt: str,
|
| 87 |
+
neg_prompt: str = None,
|
| 88 |
+
guidance_scale: float = 4.5,
|
| 89 |
+
num_inference_steps: int = 50,
|
| 90 |
+
sample_nums: int = 1
|
| 91 |
+
) -> Tuple[list, str]:
|
| 92 |
+
"""Single video inference"""
|
| 93 |
+
global model_dict, cfg, device
|
| 94 |
+
|
| 95 |
+
if model_dict is None or cfg is None:
|
| 96 |
+
return [], "❌ Please load the model first!"
|
| 97 |
+
|
| 98 |
+
if video_file is None:
|
| 99 |
+
return [], "❌ Please upload a video file!"
|
| 100 |
+
|
| 101 |
+
# Allow empty text prompt, use empty string if no prompt provided
|
| 102 |
+
if text_prompt is None:
|
| 103 |
+
text_prompt = ""
|
| 104 |
+
text_prompt = text_prompt.strip()
|
| 105 |
+
|
| 106 |
+
try:
|
| 107 |
+
logger.info(f"Processing video: {video_file}")
|
| 108 |
+
logger.info(f"Text prompt: {text_prompt}")
|
| 109 |
+
|
| 110 |
+
# Feature processing
|
| 111 |
+
visual_feats, text_feats, audio_len_in_s = feature_process(
|
| 112 |
+
video_file,
|
| 113 |
+
text_prompt,
|
| 114 |
+
model_dict,
|
| 115 |
+
cfg,
|
| 116 |
+
neg_prompt=neg_prompt
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Denoising process to generate multiple audio samples
|
| 120 |
+
# Note: The model now generates sample_nums audio samples per inference
|
| 121 |
+
# The denoise_process function returns audio with shape [batch_size, channels, samples]
|
| 122 |
+
logger.info(f"Generating {sample_nums} audio samples...")
|
| 123 |
+
audio, sample_rate = denoise_process(
|
| 124 |
+
visual_feats,
|
| 125 |
+
text_feats,
|
| 126 |
+
audio_len_in_s,
|
| 127 |
+
model_dict,
|
| 128 |
+
cfg,
|
| 129 |
+
guidance_scale=guidance_scale,
|
| 130 |
+
num_inference_steps=num_inference_steps,
|
| 131 |
+
batch_size=sample_nums
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# Create temporary files to save results
|
| 135 |
+
temp_dir = tempfile.mkdtemp()
|
| 136 |
+
video_outputs = []
|
| 137 |
+
|
| 138 |
+
# Process each generated audio sample
|
| 139 |
+
for i in range(sample_nums):
|
| 140 |
+
# Save audio file
|
| 141 |
+
audio_output = os.path.join(temp_dir, f"generated_audio_{i+1}.wav")
|
| 142 |
+
torchaudio.save(audio_output, audio[i], sample_rate)
|
| 143 |
+
|
| 144 |
+
# Merge video and audio
|
| 145 |
+
video_output = os.path.join(temp_dir, f"video_with_audio_{i+1}.mp4")
|
| 146 |
+
merge_audio_video(audio_output, video_file, video_output)
|
| 147 |
+
video_outputs.append(video_output)
|
| 148 |
+
|
| 149 |
+
logger.info(f"Inference completed! Generated {sample_nums} samples.")
|
| 150 |
+
return video_outputs, f"✅ Generated {sample_nums} audio sample(s) successfully!"
|
| 151 |
+
|
| 152 |
+
except Exception as e:
|
| 153 |
+
logger.error(f"Inference failed: {str(e)}")
|
| 154 |
+
return [], f"❌ Inference failed: {str(e)}"
|
| 155 |
+
|
| 156 |
+
def update_video_outputs(video_list, status_msg):
|
| 157 |
+
"""Update video outputs based on the number of generated samples"""
|
| 158 |
+
# Initialize all outputs as None
|
| 159 |
+
outputs = [None] * 6
|
| 160 |
+
|
| 161 |
+
# Set values based on generated videos
|
| 162 |
+
for i, video_path in enumerate(video_list[:6]): # Max 6 samples
|
| 163 |
+
outputs[i] = video_path
|
| 164 |
+
|
| 165 |
+
# Return all outputs plus status message
|
| 166 |
+
return tuple(outputs + [status_msg])
|
| 167 |
+
|
| 168 |
+
def create_gradio_interface():
|
| 169 |
+
"""Create Gradio interface"""
|
| 170 |
+
|
| 171 |
+
# Custom CSS for beautiful interface with better contrast
|
| 172 |
+
css = """
|
| 173 |
+
.gradio-container {
|
| 174 |
+
font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
|
| 175 |
+
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
|
| 176 |
+
min-height: 100vh;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
.main-header {
|
| 180 |
+
text-align: center;
|
| 181 |
+
padding: 2rem 0;
|
| 182 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 183 |
+
border-radius: 20px;
|
| 184 |
+
margin-bottom: 2rem;
|
| 185 |
+
box-shadow: 0 8px 32px rgba(0,0,0,0.15);
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
.main-header h1 {
|
| 189 |
+
color: white;
|
| 190 |
+
font-size: 3rem;
|
| 191 |
+
font-weight: 700;
|
| 192 |
+
margin-bottom: 0.5rem;
|
| 193 |
+
text-shadow: 0 2px 10px rgba(0,0,0,0.3);
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
.main-header p {
|
| 197 |
+
color: rgba(255, 255, 255, 0.95);
|
| 198 |
+
font-size: 1.2rem;
|
| 199 |
+
font-weight: 300;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
.status-card {
|
| 203 |
+
background: white;
|
| 204 |
+
border-radius: 15px;
|
| 205 |
+
padding: 1rem;
|
| 206 |
+
margin-bottom: 1.5rem;
|
| 207 |
+
border: 1px solid #e1e5e9;
|
| 208 |
+
box-shadow: 0 4px 20px rgba(0,0,0,0.08);
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
.status-card label {
|
| 212 |
+
color: #2d3748 !important;
|
| 213 |
+
font-weight: 600 !important;
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
.usage-guide h3 {
|
| 217 |
+
color: #2d3748 !important;
|
| 218 |
+
font-weight: 600 !important;
|
| 219 |
+
margin-bottom: 0.5rem !important;
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
.usage-guide p {
|
| 223 |
+
color: #4a5568 !important;
|
| 224 |
+
font-size: 1rem !important;
|
| 225 |
+
line-height: 1.6 !important;
|
| 226 |
+
margin: 0.5rem 0 !important;
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
.usage-guide strong {
|
| 230 |
+
color: #1a202c !important;
|
| 231 |
+
font-weight: 700 !important;
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
.usage-guide em {
|
| 235 |
+
color: #1a202c !important;
|
| 236 |
+
font-weight: 700 !important;
|
| 237 |
+
font-style: normal !important;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
.main-interface {
|
| 241 |
+
margin-bottom: 2rem;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
.input-section {
|
| 245 |
+
background: white;
|
| 246 |
+
border-radius: 20px;
|
| 247 |
+
padding: 2rem;
|
| 248 |
+
margin-right: 1rem;
|
| 249 |
+
box-shadow: 0 8px 32px rgba(0,0,0,0.1);
|
| 250 |
+
border: 1px solid #e1e5e9;
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
.input-section h3 {
|
| 254 |
+
color: #2d3748 !important;
|
| 255 |
+
font-weight: 600 !important;
|
| 256 |
+
margin-bottom: 1rem !important;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
.input-section label {
|
| 260 |
+
color: #4a5568 !important;
|
| 261 |
+
font-weight: 500 !important;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
.output-section {
|
| 265 |
+
background: white;
|
| 266 |
+
border-radius: 20px;
|
| 267 |
+
padding: 2rem;
|
| 268 |
+
margin-left: 1rem;
|
| 269 |
+
box-shadow: 0 8px 32px rgba(0,0,0,0.1);
|
| 270 |
+
border: 1px solid #e1e5e9;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
.output-section h3 {
|
| 274 |
+
color: #2d3748 !important;
|
| 275 |
+
font-weight: 600 !important;
|
| 276 |
+
margin-bottom: 1rem !important;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
.output-section label {
|
| 280 |
+
color: #4a5568 !important;
|
| 281 |
+
font-weight: 500 !important;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
.examples-section h3 {
|
| 285 |
+
color: #2d3748 !important;
|
| 286 |
+
font-weight: 600 !important;
|
| 287 |
+
margin-bottom: 1.5rem !important;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
.generate-btn {
|
| 291 |
+
background: linear-gradient(45deg, #667eea, #764ba2) !important;
|
| 292 |
+
border: none !important;
|
| 293 |
+
color: white !important;
|
| 294 |
+
font-weight: 600 !important;
|
| 295 |
+
font-size: 1.1rem !important;
|
| 296 |
+
padding: 12px 30px !important;
|
| 297 |
+
border-radius: 25px !important;
|
| 298 |
+
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
|
| 299 |
+
transition: all 0.3s ease !important;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
.generate-btn:hover {
|
| 303 |
+
transform: translateY(-2px) !important;
|
| 304 |
+
box-shadow: 0 8px 25px rgba(102, 126, 234, 0.6) !important;
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
.examples-section {
|
| 310 |
+
background: white;
|
| 311 |
+
border-radius: 20px;
|
| 312 |
+
padding: 2rem;
|
| 313 |
+
margin-top: 2rem;
|
| 314 |
+
box-shadow: 0 8px 32px rgba(0,0,0,0.1);
|
| 315 |
+
border: 1px solid #e1e5e9;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
.examples-section p {
|
| 319 |
+
color: #4a5568 !important;
|
| 320 |
+
margin-bottom: 1rem !important;
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
.example-row {
|
| 324 |
+
background: #f8fafc;
|
| 325 |
+
border: 1px solid #e2e8f0;
|
| 326 |
+
border-radius: 15px;
|
| 327 |
+
padding: 1.5rem;
|
| 328 |
+
margin: 1rem 0;
|
| 329 |
+
transition: all 0.3s ease;
|
| 330 |
+
align-items: center;
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
.example-row:hover {
|
| 334 |
+
border-color: #667eea;
|
| 335 |
+
transform: translateY(-2px);
|
| 336 |
+
box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15);
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
.example-row .markdown {
|
| 340 |
+
color: #2d3748 !important;
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
.example-row .markdown p {
|
| 344 |
+
color: #2d3748 !important;
|
| 345 |
+
margin: 0.5rem 0 !important;
|
| 346 |
+
line-height: 1.5 !important;
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
.example-row .markdown strong {
|
| 350 |
+
color: #1a202c !important;
|
| 351 |
+
font-weight: 600 !important;
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
/* Example grid layout styles */
|
| 355 |
+
.example-grid-row {
|
| 356 |
+
margin: 1rem 0;
|
| 357 |
+
gap: 1rem;
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
.example-item {
|
| 361 |
+
background: #f8fafc;
|
| 362 |
+
border: 1px solid #e2e8f0;
|
| 363 |
+
border-radius: 15px;
|
| 364 |
+
padding: 1rem;
|
| 365 |
+
transition: all 0.3s ease;
|
| 366 |
+
margin: 0.25rem;
|
| 367 |
+
max-width: 250px;
|
| 368 |
+
margin-left: auto;
|
| 369 |
+
margin-right: auto;
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
.example-item:hover {
|
| 373 |
+
border-color: #667eea;
|
| 374 |
+
transform: translateY(-2px);
|
| 375 |
+
box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15);
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
.example-caption {
|
| 379 |
+
margin: 0.5rem 0 !important;
|
| 380 |
+
min-height: 2.8rem !important;
|
| 381 |
+
display: flex !important;
|
| 382 |
+
align-items: flex-start !important;
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
.example-caption p {
|
| 386 |
+
color: #2d3748 !important;
|
| 387 |
+
font-size: 0.9rem !important;
|
| 388 |
+
line-height: 1.4 !important;
|
| 389 |
+
margin: 0.5rem 0 !important;
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
/* Multi-video gallery styles */
|
| 393 |
+
.additional-samples {
|
| 394 |
+
margin-top: 1rem;
|
| 395 |
+
gap: 0.5rem;
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
.additional-samples .gradio-video {
|
| 399 |
+
border-radius: 10px;
|
| 400 |
+
overflow: hidden;
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
/* Video gallery responsive layout */
|
| 404 |
+
.video-gallery {
|
| 405 |
+
display: grid;
|
| 406 |
+
gap: 1rem;
|
| 407 |
+
margin-top: 1rem;
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
.video-gallery.single {
|
| 411 |
+
grid-template-columns: 1fr;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
.video-gallery.dual {
|
| 415 |
+
grid-template-columns: 1fr 1fr;
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
.video-gallery.multi {
|
| 419 |
+
grid-template-columns: repeat(2, 1fr);
|
| 420 |
+
grid-template-rows: auto auto auto;
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
.footer-text {
|
| 424 |
+
color: #718096 !important;
|
| 425 |
+
text-align: center;
|
| 426 |
+
padding: 2rem;
|
| 427 |
+
font-size: 0.9rem;
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
/* Video component styling for consistent size */
|
| 431 |
+
.input-section video,
|
| 432 |
+
.output-section video,
|
| 433 |
+
.example-row video {
|
| 434 |
+
width: 100% !important;
|
| 435 |
+
height: 300px !important;
|
| 436 |
+
object-fit: contain !important;
|
| 437 |
+
border-radius: 10px !important;
|
| 438 |
+
background-color: #000 !important;
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
.example-row video {
|
| 442 |
+
height: 150px !important;
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
/* Fix for additional samples video display */
|
| 446 |
+
.additional-samples video {
|
| 447 |
+
height: 150px !important;
|
| 448 |
+
object-fit: contain !important;
|
| 449 |
+
border-radius: 10px !important;
|
| 450 |
+
background-color: #000 !important;
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
.additional-samples .gradio-video {
|
| 454 |
+
border-radius: 10px !important;
|
| 455 |
+
overflow: hidden !important;
|
| 456 |
+
background-color: #000 !important;
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
.additional-samples .gradio-video > div {
|
| 460 |
+
background-color: #000 !important;
|
| 461 |
+
border-radius: 10px !important;
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
/* Video container styling */
|
| 465 |
+
.input-section .video-container,
|
| 466 |
+
.output-section .video-container,
|
| 467 |
+
.example-row .video-container {
|
| 468 |
+
background-color: #000 !important;
|
| 469 |
+
border-radius: 10px !important;
|
| 470 |
+
display: flex !important;
|
| 471 |
+
align-items: center !important;
|
| 472 |
+
justify-content: center !important;
|
| 473 |
+
overflow: hidden !important;
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
/* Ensure proper alignment */
|
| 477 |
+
.example-row {
|
| 478 |
+
display: flex !important;
|
| 479 |
+
align-items: stretch !important;
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
.example-row > div {
|
| 483 |
+
display: flex !important;
|
| 484 |
+
flex-direction: column !important;
|
| 485 |
+
justify-content: center !important;
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
/* Video wrapper for better control */
|
| 489 |
+
.video-wrapper {
|
| 490 |
+
position: relative !important;
|
| 491 |
+
width: 100% !important;
|
| 492 |
+
background: #000 !important;
|
| 493 |
+
border-radius: 10px !important;
|
| 494 |
+
overflow: hidden !important;
|
| 495 |
+
display: flex !important;
|
| 496 |
+
align-items: center !important;
|
| 497 |
+
justify-content: center !important;
|
| 498 |
+
}
|
| 499 |
+
"""
|
| 500 |
+
|
| 501 |
+
with gr.Blocks(css=css, title="HunyuanVideo-Foley") as app:
|
| 502 |
+
|
| 503 |
+
# Main header
|
| 504 |
+
with gr.Column(elem_classes=["main-header"]):
|
| 505 |
+
gr.HTML("""
|
| 506 |
+
<h1>🎵 HunyuanVideo-Foley</h1>
|
| 507 |
+
<p>Text-Video-to-Audio Synthesis: Generate realistic audio from video and text descriptions</p>
|
| 508 |
+
""")
|
| 509 |
+
|
| 510 |
+
# Usage Guide
|
| 511 |
+
with gr.Column(elem_classes=["status-card"]):
|
| 512 |
+
gr.Markdown("""
|
| 513 |
+
### 📋 Quick Start Guide
|
| 514 |
+
**1.** Upload your video file\t**2.** Add optional text description\t**3.** Adjust sample numbers (1-6)\t**4.** Click Generate Audio
|
| 515 |
+
|
| 516 |
+
💡 For quick start, you can load the prepared examples by clicking the button.
|
| 517 |
+
""", elem_classes=["usage-guide"])
|
| 518 |
+
|
| 519 |
+
# Main inference interface - Input and Results side by side
|
| 520 |
+
with gr.Row(elem_classes=["main-interface"]):
|
| 521 |
+
# Input section
|
| 522 |
+
with gr.Column(scale=1, elem_classes=["input-section"]):
|
| 523 |
+
gr.Markdown("### 📹 Video Input")
|
| 524 |
+
|
| 525 |
+
video_input = gr.Video(
|
| 526 |
+
label="Upload Video",
|
| 527 |
+
info="Supported formats: MP4, AVI, MOV, etc.",
|
| 528 |
+
height=300
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
text_input = gr.Textbox(
|
| 532 |
+
label="🎯 Audio Description (English)",
|
| 533 |
+
placeholder="A person walks on frozen ice",
|
| 534 |
+
lines=3,
|
| 535 |
+
info="Describe the audio you want to generate (optional)"
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
neg_prompt_input = gr.Textbox(
|
| 539 |
+
label="🚫 Negative Prompt",
|
| 540 |
+
placeholder="noisy, harsh",
|
| 541 |
+
lines=2,
|
| 542 |
+
info="Describe what you want to avoid in the generated audio (optional, default: 'noisy, harsh')"
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
with gr.Row():
|
| 546 |
+
guidance_scale = gr.Slider(
|
| 547 |
+
minimum=1.0,
|
| 548 |
+
maximum=10.0,
|
| 549 |
+
value=4.5,
|
| 550 |
+
step=0.1,
|
| 551 |
+
label="🎚️ CFG Scale",
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
inference_steps = gr.Slider(
|
| 555 |
+
minimum=10,
|
| 556 |
+
maximum=100,
|
| 557 |
+
value=50,
|
| 558 |
+
step=5,
|
| 559 |
+
label="⚡ Steps",
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
sample_nums = gr.Slider(
|
| 563 |
+
minimum=1,
|
| 564 |
+
maximum=6,
|
| 565 |
+
value=1,
|
| 566 |
+
step=1,
|
| 567 |
+
label="🎲 Sample Nums",
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
generate_btn = gr.Button(
|
| 571 |
+
"🎵 Generate Audio",
|
| 572 |
+
variant="primary",
|
| 573 |
+
elem_classes=["generate-btn"]
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
# Results section
|
| 577 |
+
with gr.Column(scale=1, elem_classes=["output-section"]):
|
| 578 |
+
gr.Markdown("### 🎥 Generated Results")
|
| 579 |
+
|
| 580 |
+
# Multi-video gallery for displaying multiple generated samples
|
| 581 |
+
with gr.Column():
|
| 582 |
+
# Primary video (Sample 1)
|
| 583 |
+
video_output_1 = gr.Video(
|
| 584 |
+
label="Sample 1",
|
| 585 |
+
height=250,
|
| 586 |
+
visible=True
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
# Additional videos (Samples 2-6) - initially hidden
|
| 590 |
+
with gr.Row(elem_classes=["additional-samples"]):
|
| 591 |
+
with gr.Column(scale=1):
|
| 592 |
+
video_output_2 = gr.Video(
|
| 593 |
+
label="Sample 2",
|
| 594 |
+
height=150,
|
| 595 |
+
visible=False
|
| 596 |
+
)
|
| 597 |
+
video_output_3 = gr.Video(
|
| 598 |
+
label="Sample 3",
|
| 599 |
+
height=150,
|
| 600 |
+
visible=False
|
| 601 |
+
)
|
| 602 |
+
with gr.Column(scale=1):
|
| 603 |
+
video_output_4 = gr.Video(
|
| 604 |
+
label="Sample 4",
|
| 605 |
+
height=150,
|
| 606 |
+
visible=False
|
| 607 |
+
)
|
| 608 |
+
video_output_5 = gr.Video(
|
| 609 |
+
label="Sample 5",
|
| 610 |
+
height=150,
|
| 611 |
+
visible=False
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
# Sample 6 - full width
|
| 615 |
+
video_output_6 = gr.Video(
|
| 616 |
+
label="Sample 6",
|
| 617 |
+
height=150,
|
| 618 |
+
visible=False
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
result_text = gr.Textbox(
|
| 622 |
+
label="Status",
|
| 623 |
+
interactive=False,
|
| 624 |
+
lines=2
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
# Examples section at the bottom
|
| 628 |
+
with gr.Column(elem_classes=["examples-section"]):
|
| 629 |
+
gr.Markdown("### 🌟 Examples")
|
| 630 |
+
gr.Markdown("Click on any example to load it into the interface above")
|
| 631 |
+
|
| 632 |
+
# Define your custom examples here - 8 examples total
|
| 633 |
+
examples_data = [
|
| 634 |
+
# Example 1
|
| 635 |
+
{
|
| 636 |
+
"caption": "A person walks on frozen ice",
|
| 637 |
+
"video_path": "examples/1_video.mp4",
|
| 638 |
+
"result_path": "examples/1_result.mp4"
|
| 639 |
+
},
|
| 640 |
+
# Example 2
|
| 641 |
+
{
|
| 642 |
+
"caption": "With a faint sound as their hands parted, the two embraced, a soft 'mm' escaping between them.",
|
| 643 |
+
"video_path": "examples/2_video.mp4",
|
| 644 |
+
"result_path": "examples/2_result.mp4"
|
| 645 |
+
},
|
| 646 |
+
# Example 3
|
| 647 |
+
{
|
| 648 |
+
"caption": "The sound of the number 3's bouncing footsteps is as light and clear as glass marbles hitting the ground. Each step carries a magical sound.",
|
| 649 |
+
"video_path": "examples/3_video.mp4",
|
| 650 |
+
"result_path": "examples/3_result.mp4"
|
| 651 |
+
},
|
| 652 |
+
# Example 4
|
| 653 |
+
{
|
| 654 |
+
"caption": "gentle gurgling of the stream's current, and music plays in the background which is a beautiful and serene piano solo with a hint of classical charm, evoking a sense of peace and serenity in people's hearts.",
|
| 655 |
+
"video_path": "examples/4_video.mp4",
|
| 656 |
+
"result_path": "examples/4_result.mp4"
|
| 657 |
+
},
|
| 658 |
+
# Example 5 - Add your new examples here
|
| 659 |
+
{
|
| 660 |
+
"caption": "snow crunching under the snowboard's edge.",
|
| 661 |
+
"video_path": "examples/5_video.mp4",
|
| 662 |
+
"result_path": "examples/5_result.mp4"
|
| 663 |
+
},
|
| 664 |
+
# Example 6
|
| 665 |
+
{
|
| 666 |
+
"caption": "The crackling of the fire, the whooshing of the flames, and the occasional crisp popping of charred leaves filled the forest.",
|
| 667 |
+
"video_path": "examples/6_video.mp4",
|
| 668 |
+
"result_path": "examples/6_result.mp4"
|
| 669 |
+
},
|
| 670 |
+
# Example 7
|
| 671 |
+
{
|
| 672 |
+
"caption": "humming of the scooter engine accelerates slowly.",
|
| 673 |
+
"video_path": "examples/7_video.mp4",
|
| 674 |
+
"result_path": "examples/7_result.mp4"
|
| 675 |
+
},
|
| 676 |
+
# Example 8
|
| 677 |
+
{
|
| 678 |
+
"caption": "splash of water and loud thud as person hits the surface.",
|
| 679 |
+
"video_path": "examples/8_video.mp4",
|
| 680 |
+
"result_path": "examples/8_result.mp4"
|
| 681 |
+
}
|
| 682 |
+
]
|
| 683 |
+
|
| 684 |
+
# Create example grid - 4 examples per row, 2 rows total
|
| 685 |
+
example_buttons = []
|
| 686 |
+
for row in range(2): # 2 rows
|
| 687 |
+
with gr.Row(elem_classes=["example-grid-row"]):
|
| 688 |
+
for col in range(4): # 4 columns
|
| 689 |
+
idx = row * 4 + col
|
| 690 |
+
if idx < len(examples_data):
|
| 691 |
+
example = examples_data[idx]
|
| 692 |
+
|
| 693 |
+
with gr.Column(scale=1, elem_classes=["example-item"]):
|
| 694 |
+
# Video thumbnail
|
| 695 |
+
if os.path.exists(example['video_path']):
|
| 696 |
+
example_video = gr.Video(
|
| 697 |
+
value=example['video_path'],
|
| 698 |
+
label=f"Example {idx+1}",
|
| 699 |
+
interactive=False,
|
| 700 |
+
show_label=True,
|
| 701 |
+
height=180
|
| 702 |
+
)
|
| 703 |
+
else:
|
| 704 |
+
example_video = gr.HTML(f"""
|
| 705 |
+
<div style="background: #f0f0f0; padding: 15px; text-align: center; border-radius: 8px; height: 180px; display: flex; align-items: center; justify-content: center;">
|
| 706 |
+
<div>
|
| 707 |
+
<p style="color: #666; margin: 0; font-size: 12px;">📹 Video not found</p>
|
| 708 |
+
<small style="color: #999; font-size: 10px;">{example['video_path']}</small>
|
| 709 |
+
</div>
|
| 710 |
+
</div>
|
| 711 |
+
""")
|
| 712 |
+
|
| 713 |
+
# Caption (truncated for grid layout)
|
| 714 |
+
caption_preview = example['caption'][:60] + "..." if len(example['caption']) > 60 else example['caption']
|
| 715 |
+
gr.Markdown(f"{caption_preview}", elem_classes=["example-caption"])
|
| 716 |
+
|
| 717 |
+
# Load button
|
| 718 |
+
example_btn = gr.Button(
|
| 719 |
+
f"Load Example {idx+1}",
|
| 720 |
+
variant="secondary",
|
| 721 |
+
size="sm"
|
| 722 |
+
)
|
| 723 |
+
example_buttons.append((example_btn, example))
|
| 724 |
+
|
| 725 |
+
# Event handlers
|
| 726 |
+
def process_inference(video_file, text_prompt, neg_prompt, guidance_scale, inference_steps, sample_nums):
|
| 727 |
+
# Generate videos
|
| 728 |
+
video_list, status_msg = infer_single_video(
|
| 729 |
+
video_file, text_prompt, neg_prompt, guidance_scale, inference_steps, int(sample_nums)
|
| 730 |
+
)
|
| 731 |
+
# Update outputs with proper visibility
|
| 732 |
+
return update_video_outputs(video_list, status_msg)
|
| 733 |
+
|
| 734 |
+
# Add dynamic visibility control based on sample_nums
|
| 735 |
+
def update_visibility(sample_nums):
|
| 736 |
+
sample_nums = int(sample_nums)
|
| 737 |
+
return [
|
| 738 |
+
gr.update(visible=True), # Sample 1 always visible
|
| 739 |
+
gr.update(visible=sample_nums >= 2), # Sample 2
|
| 740 |
+
gr.update(visible=sample_nums >= 3), # Sample 3
|
| 741 |
+
gr.update(visible=sample_nums >= 4), # Sample 4
|
| 742 |
+
gr.update(visible=sample_nums >= 5), # Sample 5
|
| 743 |
+
gr.update(visible=sample_nums >= 6), # Sample 6
|
| 744 |
+
]
|
| 745 |
+
|
| 746 |
+
# Update visibility when sample_nums changes
|
| 747 |
+
sample_nums.change(
|
| 748 |
+
fn=update_visibility,
|
| 749 |
+
inputs=[sample_nums],
|
| 750 |
+
outputs=[video_output_1, video_output_2, video_output_3, video_output_4, video_output_5, video_output_6]
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
generate_btn.click(
|
| 754 |
+
fn=process_inference,
|
| 755 |
+
inputs=[video_input, text_input, neg_prompt_input, guidance_scale, inference_steps, sample_nums],
|
| 756 |
+
outputs=[
|
| 757 |
+
video_output_1, # Sample 1 value
|
| 758 |
+
video_output_2, # Sample 2 value
|
| 759 |
+
video_output_3, # Sample 3 value
|
| 760 |
+
video_output_4, # Sample 4 value
|
| 761 |
+
video_output_5, # Sample 5 value
|
| 762 |
+
video_output_6, # Sample 6 value
|
| 763 |
+
result_text
|
| 764 |
+
]
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
# Add click handlers for example buttons
|
| 768 |
+
for btn, example in example_buttons:
|
| 769 |
+
def create_example_handler(ex):
|
| 770 |
+
def handler():
|
| 771 |
+
# Check if files exist, if not, return placeholder message
|
| 772 |
+
if os.path.exists(ex['video_path']):
|
| 773 |
+
video_file = ex['video_path']
|
| 774 |
+
else:
|
| 775 |
+
video_file = None
|
| 776 |
+
|
| 777 |
+
if os.path.exists(ex['result_path']):
|
| 778 |
+
result_video = ex['result_path']
|
| 779 |
+
else:
|
| 780 |
+
result_video = None
|
| 781 |
+
|
| 782 |
+
status_msg = f"✅ Loaded example with caption: {ex['caption'][:50]}..."
|
| 783 |
+
if not video_file:
|
| 784 |
+
status_msg += f"\n⚠️ Video file not found: {ex['video_path']}"
|
| 785 |
+
if not result_video:
|
| 786 |
+
status_msg += f"\n⚠️ Result video not found: {ex['result_path']}"
|
| 787 |
+
|
| 788 |
+
return video_file, ex['caption'], "noisy, harsh", result_video, status_msg
|
| 789 |
+
return handler
|
| 790 |
+
|
| 791 |
+
btn.click(
|
| 792 |
+
fn=create_example_handler(example),
|
| 793 |
+
outputs=[video_input, text_input, neg_prompt_input, video_output_1, result_text]
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
# Footer
|
| 797 |
+
gr.HTML("""
|
| 798 |
+
<div class="footer-text">
|
| 799 |
+
<p>🚀 Powered by HunyuanVideo-Foley | Generate high-quality audio from video and text descriptions</p>
|
| 800 |
+
</div>
|
| 801 |
+
""")
|
| 802 |
+
|
| 803 |
+
return app
|
| 804 |
+
|
| 805 |
+
def set_manual_seed(global_seed):
    """Seed Python, NumPy and PyTorch RNGs for reproducible generation.

    Args:
        global_seed: Integer seed applied to all three generators.
    """
    # Apply the same seed to every RNG the pipeline may draw from.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(global_seed)
|
| 809 |
+
|
| 810 |
+
if __name__ == "__main__":
    # Fixed seed so repeated launches produce comparable samples.
    set_manual_seed(1)

    # Route loguru output straight to stdout at INFO level.
    logger.remove()
    logger.add(lambda msg: print(msg, end=''), level="INFO")

    # Load the model eagerly so the UI is usable as soon as it appears.
    logger.info("Starting application and loading model...")
    model_load_result = auto_load_models()
    logger.info(model_load_result)

    app = create_gradio_interface()

    # auto_load_models reports success via its message text.
    if "successfully" in model_load_result:
        logger.info("Application ready, model loaded")

    app.launch(
        server_name="0.0.0.0",
        server_port=8080,
        share=False,
        debug=False,
        show_error=True,
    )
|
HunyuanVideo-Foley/hunyuanvideo_foley/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment
|
| 3 |
+
for High-Fidelity Foley Audio Generation
|
| 4 |
+
|
| 5 |
+
This package provides tools for generating high-quality Foley audio effects
|
| 6 |
+
from video content using multimodal diffusion models.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
__version__ = "1.0.0"
|
| 10 |
+
__author__ = "Tencent Hunyuan Team"
|
| 11 |
+
__email__ = "hunyuan@tencent.com"
|
| 12 |
+
|
| 13 |
+
# Import main components for easy access
|
| 14 |
+
try:
|
| 15 |
+
from .utils.model_utils import load_model, denoise_process
|
| 16 |
+
from .utils.feature_utils import feature_process
|
| 17 |
+
from .utils.media_utils import merge_audio_video
|
| 18 |
+
from .utils.config_utils import AttributeDict
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"__version__",
|
| 22 |
+
"load_model",
|
| 23 |
+
"denoise_process",
|
| 24 |
+
"feature_process",
|
| 25 |
+
"merge_audio_video",
|
| 26 |
+
"AttributeDict"
|
| 27 |
+
]
|
| 28 |
+
except ImportError:
|
| 29 |
+
# Handle missing dependencies gracefully during installation
|
| 30 |
+
__all__ = ["__version__"]
|
HunyuanVideo-Foley/hunyuanvideo_foley/cli.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Command Line Interface for HunyuanVideo-Foley
|
| 4 |
+
|
| 5 |
+
Provides command-line access to the main inference functionality.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
import argparse
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
def main():
    """Main CLI entry point.

    Parses command-line arguments and dispatches to one of three modes:
    the Gradio web UI, single-video inference, or CSV batch processing.
    Exits with status 1 on missing dependencies or runtime errors.
    """
    parser = argparse.ArgumentParser(
        description="HunyuanVideo-Foley: Generate Foley audio from video and text",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Single video generation
  hunyuanvideo-foley --model_path ./models --single_video video.mp4 --single_prompt "footsteps on gravel"

  # Batch processing
  hunyuanvideo-foley --model_path ./models --csv_path batch.csv --output_dir ./outputs

  # Start Gradio interface
  hunyuanvideo-foley --gradio --model_path ./models
        """
    )

    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the pretrained model directory")
    parser.add_argument("--config_path", type=str,
                        default="configs/hunyuanvideo-foley-xxl.yaml",
                        help="Path to the model configuration file")

    # Input options (exactly one of these is required)
    group_input = parser.add_mutually_exclusive_group(required=True)
    group_input.add_argument("--single_video", type=str,
                             help="Path to single video file for processing")
    group_input.add_argument("--csv_path", type=str,
                             help="Path to CSV file with video paths and prompts")
    group_input.add_argument("--gradio", action="store_true",
                             help="Launch Gradio web interface")

    # Generation options
    parser.add_argument("--single_prompt", type=str,
                        help="Text prompt for single video (required with --single_video)")
    parser.add_argument("--output_dir", type=str, default="./outputs",
                        help="Output directory for generated audio files")
    parser.add_argument("--guidance_scale", type=float, default=4.5,
                        help="Guidance scale for generation (default: 4.5)")
    parser.add_argument("--num_inference_steps", type=int, default=50,
                        help="Number of inference steps (default: 50)")
    parser.add_argument("--neg_prompt", type=str,
                        help="Negative prompt to avoid certain audio characteristics")

    # System options
    parser.add_argument("--device", type=str, default="auto",
                        choices=["auto", "cpu", "cuda"],
                        help="Device to use for inference")
    parser.add_argument("--gpu_id", type=int, default=0,
                        help="GPU ID to use (default: 0)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducible generation")

    args = parser.parse_args()

    # argparse cannot express this conditional requirement directly.
    if args.single_video and not args.single_prompt:
        parser.error("--single_prompt is required when using --single_video")

    # Import inside the handler so that a bare install can still show --help.
    try:
        if args.gradio:
            _launch_gradio(args)
        elif args.single_video:
            _process_single_video(args)
        elif args.csv_path:
            _process_batch(args)
    except ImportError as e:
        # Fix: the original used f-string prefixes on strings with no
        # placeholders (ruff F541); output is unchanged.
        print("Error: Missing required dependencies. Please install with: pip install hunyuanvideo-foley[all]")
        print(f"Import error: {e}")
        sys.exit(1)
    except Exception as e:
        # Top-level CLI boundary: report and exit non-zero.
        print(f"Error: {e}")
        sys.exit(1)
|
| 87 |
+
|
| 88 |
+
def _launch_gradio(args):
    """Launch Gradio web interface.

    Passes the model location to the Gradio app through the environment,
    then runs gradio_app.py (located two levels above this module) in a
    child interpreter.
    """
    import os
    import subprocess

    os.environ["HIFI_FOLEY_MODEL_PATH"] = args.model_path

    script = Path(__file__).parent.parent / "gradio_app.py"
    subprocess.run([sys.executable, str(script)])
|
| 97 |
+
|
| 98 |
+
def _process_single_video(args):
|
| 99 |
+
"""Process a single video file."""
|
| 100 |
+
from . import infer
|
| 101 |
+
|
| 102 |
+
print(f"Processing video: {args.single_video}")
|
| 103 |
+
print(f"Prompt: {args.single_prompt}")
|
| 104 |
+
|
| 105 |
+
# This would need to be implemented to match the actual infer.py interface
|
| 106 |
+
# For now, redirect to the original script
|
| 107 |
+
import subprocess
|
| 108 |
+
cmd = [
|
| 109 |
+
sys.executable, "infer.py",
|
| 110 |
+
"--model_path", args.model_path,
|
| 111 |
+
"--config_path", args.config_path,
|
| 112 |
+
"--single_video", args.single_video,
|
| 113 |
+
"--single_prompt", args.single_prompt,
|
| 114 |
+
"--output_dir", args.output_dir,
|
| 115 |
+
"--guidance_scale", str(args.guidance_scale),
|
| 116 |
+
"--num_inference_steps", str(args.num_inference_steps)
|
| 117 |
+
]
|
| 118 |
+
if args.neg_prompt:
|
| 119 |
+
cmd.extend(["--neg_prompt", args.neg_prompt])
|
| 120 |
+
|
| 121 |
+
subprocess.run(cmd)
|
| 122 |
+
|
| 123 |
+
def _process_batch(args):
|
| 124 |
+
"""Process a batch of videos from CSV."""
|
| 125 |
+
import subprocess
|
| 126 |
+
cmd = [
|
| 127 |
+
sys.executable, "infer.py",
|
| 128 |
+
"--model_path", args.model_path,
|
| 129 |
+
"--config_path", args.config_path,
|
| 130 |
+
"--csv_path", args.csv_path,
|
| 131 |
+
"--output_dir", args.output_dir,
|
| 132 |
+
"--guidance_scale", str(args.guidance_scale),
|
| 133 |
+
"--num_inference_steps", str(args.num_inference_steps)
|
| 134 |
+
]
|
| 135 |
+
if args.neg_prompt:
|
| 136 |
+
cmd.extend(["--neg_prompt", args.neg_prompt])
|
| 137 |
+
|
| 138 |
+
subprocess.run(cmd)
|
| 139 |
+
|
| 140 |
+
if __name__ == "__main__":
|
| 141 |
+
main()
|
HunyuanVideo-Foley/hunyuanvideo_foley/constants.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Constants used throughout the HunyuanVideo-Foley project."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, List
|
| 4 |
+
|
| 5 |
+
# Model configuration
|
| 6 |
+
DEFAULT_AUDIO_SAMPLE_RATE = 48000
|
| 7 |
+
DEFAULT_VIDEO_FPS = 25
|
| 8 |
+
DEFAULT_AUDIO_CHANNELS = 2
|
| 9 |
+
|
| 10 |
+
# Video processing
|
| 11 |
+
MAX_VIDEO_DURATION_SECONDS = 15.0
|
| 12 |
+
MIN_VIDEO_DURATION_SECONDS = 1.0
|
| 13 |
+
|
| 14 |
+
# Audio processing
|
| 15 |
+
AUDIO_VAE_LATENT_DIM = 128
|
| 16 |
+
AUDIO_FRAME_RATE = 75 # frames per second in latent space
|
| 17 |
+
|
| 18 |
+
# Visual features
|
| 19 |
+
FPS_VISUAL: Dict[str, int] = {
|
| 20 |
+
"siglip2": 8,
|
| 21 |
+
"synchformer": 25
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
# Model paths (can be overridden by environment variables)
|
| 25 |
+
DEFAULT_MODEL_PATH = "./pretrained_models/"
|
| 26 |
+
DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"
|
| 27 |
+
|
| 28 |
+
# Inference parameters
|
| 29 |
+
DEFAULT_GUIDANCE_SCALE = 4.5
|
| 30 |
+
DEFAULT_NUM_INFERENCE_STEPS = 50
|
| 31 |
+
MIN_GUIDANCE_SCALE = 1.0
|
| 32 |
+
MAX_GUIDANCE_SCALE = 10.0
|
| 33 |
+
MIN_INFERENCE_STEPS = 10
|
| 34 |
+
MAX_INFERENCE_STEPS = 100
|
| 35 |
+
|
| 36 |
+
# Text processing
|
| 37 |
+
MAX_TEXT_LENGTH = 100
|
| 38 |
+
DEFAULT_NEGATIVE_PROMPT = "noisy, harsh"
|
| 39 |
+
|
| 40 |
+
# File extensions
|
| 41 |
+
SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"]
|
| 42 |
+
SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"]
|
| 43 |
+
|
| 44 |
+
# Quality settings
|
| 45 |
+
AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = {
|
| 46 |
+
"high": ["-b:a", "192k"],
|
| 47 |
+
"medium": ["-b:a", "128k"],
|
| 48 |
+
"low": ["-b:a", "96k"]
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
# Error messages
|
| 52 |
+
ERROR_MESSAGES: Dict[str, str] = {
|
| 53 |
+
"model_not_loaded": "Model is not loaded. Please load the model first.",
|
| 54 |
+
"invalid_video_format": "Unsupported video format. Supported formats: {formats}",
|
| 55 |
+
"video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds",
|
| 56 |
+
"ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html"
|
| 57 |
+
}
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/__init__.py
ADDED
|
File without changes
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "1.0.0"

# preserved here for legacy reasons
__model_version__ = "latest"

import audiotools

# Register this package's modules as "internal" and einops as an allowed
# "external" dependency for audiotools' BaseModel serialization, so model
# checkpoints can be packaged together with their source code.
audiotools.ml.BaseModel.INTERN += ["dac.**"]
audiotools.ml.BaseModel.EXTERN += ["einops"]


# Public API of the dac_vae package.
from . import nn
from . import model
from . import utils
from .model import DAC
from .model import DACFile
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/__main__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
import argbind
|
| 4 |
+
|
| 5 |
+
from .utils import download
|
| 6 |
+
from .utils.decode import decode
|
| 7 |
+
from .utils.encode import encode
|
| 8 |
+
|
| 9 |
+
STAGES = ["encode", "decode", "download"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def run(stage: str):
|
| 13 |
+
"""Run stages.
|
| 14 |
+
|
| 15 |
+
Parameters
|
| 16 |
+
----------
|
| 17 |
+
stage : str
|
| 18 |
+
Stage to run
|
| 19 |
+
"""
|
| 20 |
+
if stage not in STAGES:
|
| 21 |
+
raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
|
| 22 |
+
stage_fn = globals()[stage]
|
| 23 |
+
|
| 24 |
+
if stage == "download":
|
| 25 |
+
stage_fn()
|
| 26 |
+
return
|
| 27 |
+
|
| 28 |
+
stage_fn()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
    # First positional argument selects the stage (encode/decode/download);
    # the remaining argv is parsed by argbind under that group's scope.
    group = sys.argv.pop(1)
    args = argbind.parse_args(group=group)

    with argbind.scope(args):
        run(group)
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base import CodecMixin
|
| 2 |
+
from .base import DACFile
|
| 3 |
+
from .dac import DAC
|
| 4 |
+
from .discriminator import Discriminator
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/base.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import tqdm
|
| 9 |
+
from audiotools import AudioSignal
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
SUPPORTED_VERSIONS = ["1.0.0"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class DACFile:
    """Serializable container for DAC-compressed audio codes plus the
    metadata needed to reconstruct the original signal."""

    codes: torch.Tensor  # integer codebook indices, stored as uint16 on disk

    # Metadata
    chunk_length: int  # latent frames per encoded chunk
    original_length: int  # original signal length in samples
    input_db: float  # input loudness, used to restore level after decoding
    channels: int  # channel count of the original signal
    sample_rate: int  # original sample rate (Hz)
    padding: bool  # whether padded (unchunked) inference was used
    dac_version: str  # codec version that produced the codes

    def save(self, path):
        """Write codes + metadata to ``path`` (suffix forced to ``.dac``)."""
        artifacts = {
            # uint16 keeps the file small; load() widens back to int.
            "codes": self.codes.numpy().astype(np.uint16),
            "metadata": {
                # NOTE(review): treats input_db as a tensor despite the
                # ``float`` annotation above — confirm against compress().
                "input_db": self.input_db.numpy().astype(np.float32),
                "original_length": self.original_length,
                "sample_rate": self.sample_rate,
                "chunk_length": self.chunk_length,
                "channels": self.channels,
                "padding": self.padding,
                "dac_version": SUPPORTED_VERSIONS[-1],
            },
        }
        path = Path(path).with_suffix(".dac")
        with open(path, "wb") as f:
            # np.save on a dict pickles it; load() must pass allow_pickle=True.
            np.save(f, artifacts)
        return path

    @classmethod
    def load(cls, path):
        """Load a ``.dac`` file written by :meth:`save`.

        Raises RuntimeError when the stored codec version is unsupported.
        """
        artifacts = np.load(path, allow_pickle=True)[()]
        codes = torch.from_numpy(artifacts["codes"].astype(int))
        if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
            raise RuntimeError(
                f"Given file {path} can't be loaded with this version of descript-audio-codec."
            )
        return cls(codes=codes, **artifacts["metadata"])
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CodecMixin:
    """Mixin adding padded/unpadded inference control and chunked,
    constant-memory compress/decompress helpers to a convolutional codec."""

    @property
    def padding(self):
        # Lazily default to padded (whole-signal) inference.
        if not hasattr(self, "_padding"):
            self._padding = True
        return self._padding

    @padding.setter
    def padding(self, value):
        # Toggling to False switches every 1-D conv to "valid" convolution
        # (used for chunked inference); True restores the saved padding.
        assert isinstance(value, bool)

        layers = [
            l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
        ]

        for layer in layers:
            if value:
                # Restore the padding captured when it was first disabled.
                if hasattr(layer, "original_padding"):
                    layer.padding = layer.original_padding
            else:
                layer.original_padding = layer.padding
                layer.padding = tuple(0 for _ in range(len(layer.padding)))

        self._padding = value

    def get_delay(self):
        """Return the per-side sample delay introduced by unpadded inference.

        Works by inverting each conv layer's output-length formula, walking
        the layers in reverse to find the input length that yields ``l_out``
        output samples, then halving the difference.
        """
        # Any number works here, delay is invariant to input length
        l_out = self.get_output_length(0)
        L = l_out

        layers = []
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                layers.append(layer)

        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            # Inverse of the corresponding forward formula in
            # get_output_length (note the Conv/ConvTranspose roles swap).
            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        """Propagate ``input_length`` through every 1-D conv layer's
        output-size formula (valid convolution, no padding)."""
        L = input_length
        # Calculate output length
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L

    @torch.no_grad()
    def compress(
        self,
        audio_path_or_signal: Union[str, Path, AudioSignal],
        win_duration: float = 1.0,
        verbose: bool = False,
        normalize_db: float = -16,
        n_quantizers: int = None,
    ) -> DACFile:
        """Processes an audio signal from a file or AudioSignal object into
        discrete codes. This function processes the signal in short windows,
        using constant GPU memory.

        Parameters
        ----------
        audio_path_or_signal : Union[str, Path, AudioSignal]
            audio signal to reconstruct
        win_duration : float, optional
            window duration in seconds, by default 1.0
        verbose : bool, optional
            show a progress bar, by default False
        normalize_db : float, optional
            target loudness for normalization (None to skip), by default -16
        n_quantizers : int, optional
            number of quantizers passed to encode, by default all

        Returns
        -------
        DACFile
            Object containing compressed codes and metadata
            required for decompression
        """
        audio_signal = audio_path_or_signal
        if isinstance(audio_signal, (str, Path)):
            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

        self.eval()
        original_padding = self.padding
        original_device = audio_signal.device

        # Work on a mono copy; channels are folded into the batch dim below.
        audio_signal = audio_signal.clone()
        audio_signal = audio_signal.to_mono()
        original_sr = audio_signal.sample_rate

        resample_fn = audio_signal.resample
        loudness_fn = audio_signal.loudness

        # If audio is > 10 minutes long, use the ffmpeg versions
        # NOTE(review): 10 * 60 * 60 is 10 hours, not 10 minutes as the
        # comment above claims — confirm which threshold is intended.
        if audio_signal.signal_duration >= 10 * 60 * 60:
            resample_fn = audio_signal.ffmpeg_resample
            loudness_fn = audio_signal.ffmpeg_loudness

        original_length = audio_signal.signal_length
        resample_fn(self.sample_rate)
        input_db = loudness_fn()

        if normalize_db is not None:
            audio_signal.normalize(normalize_db)
        audio_signal.ensure_max_of_audio()

        nb, nac, nt = audio_signal.audio_data.shape
        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
        win_duration = (
            audio_signal.signal_duration if win_duration is None else win_duration
        )

        if audio_signal.signal_duration <= win_duration:
            # Unchunked compression (used if signal length < win duration)
            self.padding = True
            n_samples = nt
            hop = nt
        else:
            # Chunked inference
            self.padding = False
            # Zero-pad signal on either side by the delay
            audio_signal.zero_pad(self.delay, self.delay)
            n_samples = int(win_duration * self.sample_rate)
            # Round n_samples to nearest hop length multiple
            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
            hop = self.get_output_length(n_samples)

        codes = []
        range_fn = range if not verbose else tqdm.trange

        for i in range_fn(0, nt, hop):
            x = audio_signal[..., i : i + n_samples]
            # Pad the (possibly short) final window up to the full size.
            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

            audio_data = x.audio_data.to(self.device)
            audio_data = self.preprocess(audio_data, self.sample_rate)
            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
            codes.append(c.to(original_device))
            chunk_length = c.shape[-1]

        codes = torch.cat(codes, dim=-1)

        dac_file = DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=original_length,
            input_db=input_db,
            channels=nac,
            sample_rate=original_sr,
            padding=self.padding,
            dac_version=SUPPORTED_VERSIONS[-1],
        )

        # NOTE(review): this slice runs after `codes` was stored in dac_file,
        # so it does not affect the returned object — confirm intent.
        if n_quantizers is not None:
            codes = codes[:, :n_quantizers, :]

        self.padding = original_padding
        return dac_file

    @torch.no_grad()
    def decompress(
        self,
        obj: Union[str, Path, DACFile],
        verbose: bool = False,
    ) -> AudioSignal:
        """Reconstruct audio from a given .dac file

        Parameters
        ----------
        obj : Union[str, Path, DACFile]
            .dac file location or corresponding DACFile object.
        verbose : bool, optional
            Prints progress if True, by default False

        Returns
        -------
        AudioSignal
            Object with the reconstructed audio
        """
        self.eval()
        if isinstance(obj, (str, Path)):
            obj = DACFile.load(obj)

        # Mirror the padding mode used at compression time.
        original_padding = self.padding
        self.padding = obj.padding

        range_fn = range if not verbose else tqdm.trange
        codes = obj.codes
        original_device = codes.device
        chunk_length = obj.chunk_length
        recons = []

        # Decode chunk by chunk to keep device memory constant.
        for i in range_fn(0, codes.shape[-1], chunk_length):
            c = codes[..., i : i + chunk_length].to(self.device)
            z = self.quantizer.from_codes(c)[0]
            r = self.decode(z)
            recons.append(r.to(original_device))

        recons = torch.cat(recons, dim=-1)
        recons = AudioSignal(recons, self.sample_rate)

        resample_fn = recons.resample
        loudness_fn = recons.loudness

        # If audio is > 10 minutes long, use the ffmpeg versions
        # NOTE(review): same 10-hour-vs-10-minute mismatch as in compress().
        if recons.signal_duration >= 10 * 60 * 60:
            resample_fn = recons.ffmpeg_resample
            loudness_fn = recons.ffmpeg_loudness

        # Restore the original loudness and sample rate.
        if obj.input_db is not None:
            recons.normalize(obj.input_db)

        resample_fn(obj.sample_rate)

        if obj.original_length is not None:
            # Trim padding back to the original length and unfold channels.
            recons = recons[..., : obj.original_length]
            loudness_fn()
            recons.audio_data = recons.audio_data.reshape(
                -1, obj.channels, obj.original_length
            )
        else:
            loudness_fn()

        self.padding = original_padding
        return recons
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/dac.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from audiotools import AudioSignal
|
| 8 |
+
from audiotools.ml import BaseModel
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from .base import CodecMixin
|
| 12 |
+
from ..nn.layers import Snake1d
|
| 13 |
+
from ..nn.layers import WNConv1d
|
| 14 |
+
from ..nn.layers import WNConvTranspose1d
|
| 15 |
+
from ..nn.quantize import ResidualVectorQuantize
|
| 16 |
+
from ..nn.vae_utils import DiagonalGaussianDistribution
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def init_weights(m):
    """Initialize Conv1d modules: truncated-normal weights (std 0.02), zero bias.

    Intended for use with ``nn.Module.apply``; non-Conv1d modules are left
    untouched.
    """
    if not isinstance(m, nn.Conv1d):
        return
    nn.init.trunc_normal_(m.weight, std=0.02)
    nn.init.constant_(m.bias, 0)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ResidualUnit(nn.Module):
    """Dilated residual block: Snake -> dilated 7-tap conv -> Snake -> 1x1 conv,
    with a center-cropped skip connection."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        # "Same"-style padding for the dilated 7-tap convolution.
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        residual = self.block(x)
        # Center-crop the input so the skip connection lines up with the
        # (possibly shorter) conv output when padding is disabled.
        trim = (x.shape[-1] - residual.shape[-1]) // 2
        skip = x[..., trim:-trim] if trim > 0 else x
        return skip + residual
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class EncoderBlock(nn.Module):
    """Three dilated residual units (at half width) followed by a strided
    downsampling convolution that doubles the channel count."""

    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        half = dim // 2
        layers = [
            ResidualUnit(half, dilation=1),
            ResidualUnit(half, dilation=3),
            ResidualUnit(half, dilation=9),
            Snake1d(half),
            WNConv1d(
                half,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Encoder(nn.Module):
    """Strided convolutional encoder mapping a 1-channel waveform to latents.

    Each stage downsamples time by its stride and doubles the channel width;
    a final conv projects to ``d_latent`` channels.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        # First convolution: waveform (1 channel) -> d_model channels.
        layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Downsampling stages; width doubles at each stride.
        width = d_model
        for stride in strides:
            width *= 2
            layers.append(EncoderBlock(width, stride=stride))

        # Final projection to the latent width.
        layers += [
            Snake1d(width),
            WNConv1d(width, d_latent, kernel_size=3, padding=1),
        ]

        self.block = nn.Sequential(*layers)
        self.enc_dim = width

    def forward(self, x):
        return self.block(x)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class DecoderBlock(nn.Module):
    """Upsampling stage: Snake -> transposed conv (stride = upsample factor)
    -> three dilated residual units at the new width."""

    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        upsample = WNConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
            output_padding=stride % 2,
        )
        self.block = nn.Sequential(
            Snake1d(input_dim),
            upsample,
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Decoder(nn.Module):
    """Upsampling convolutional decoder mapping latents back to a waveform.

    Args:
        input_channel: Number of latent channels fed into the decoder.
        channels: Width of the first decoder stage; halved at each upsample.
        rates: Upsampling factors, one per DecoderBlock.
        d_out: Number of output audio channels, by default 1 (mono).
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
    ):
        super().__init__()

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks; channel width halves at each stage.
        # Initialize output_dim up front so an empty `rates` list is valid
        # (the original code raised NameError after the loop in that case).
        output_dim = channels
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]

        # Final projection to the waveform, squashed into [-1, 1] by Tanh.
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class DAC(BaseModel, CodecMixin):
    """Descript Audio Codec model.

    Convolutional encoder/decoder over waveforms with either a residual
    vector quantizer bottleneck (discrete codes, the default) or, when
    ``continuous=True``, a diagonal-Gaussian VAE bottleneck.
    """

    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 44100,
        continuous: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate
        self.continuous = continuous

        if latent_dim is None:
            # Default latent width: encoder channels after one doubling per stage.
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        # Total waveform-to-latent downsampling factor.
        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)

        if not continuous:
            # Discrete path: residual vector quantization bottleneck.
            self.n_codebooks = n_codebooks
            self.codebook_size = codebook_size
            self.codebook_dim = codebook_dim
            self.quantizer = ResidualVectorQuantize(
                input_dim=latent_dim,
                n_codebooks=n_codebooks,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_dropout=quantizer_dropout,
            )
        else:
            # Continuous path: 1x1 convs projecting to/from the Gaussian
            # parameterization (mean and log-variance -> 2 * latent_dim).
            self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
            self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
        )
        self.sample_rate = sample_rate
        self.apply(init_weights)

        self.delay = self.get_delay()

    @property
    def dtype(self):
        """Get the dtype of the model parameters."""
        # Return the dtype of the first parameter found
        for param in self.parameters():
            return param.dtype
        return torch.float32  # fallback

    @property
    def device(self):
        """Get the device of the model parameters."""
        # Return the device of the first parameter found
        for param in self.parameters():
            return param.device
        return torch.device('cpu')  # fallback

    def preprocess(self, audio_data, sample_rate):
        """Validate the sample rate and right-pad audio so its length is a
        multiple of ``hop_length`` (so encoder strides divide evenly)."""
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode given audio data into the latent bottleneck.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        tuple
            ``(z, codes, latents, commitment_loss, codebook_loss)``.
            In discrete mode, ``z`` is the quantized continuous
            representation [B x D x T], ``codes`` the codebook indices
            [B x N x T], and ``latents`` the pre-quantization projections
            [B x N*D x T]. In continuous mode, ``z`` is a
            DiagonalGaussianDistribution and the remaining entries are
            ``None``/0 placeholders.
        """
        z = self.encoder(audio_data)  # [B x D x T]
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
        else:
            z = self.quant_conv(z)  # [B x 2D x T]
            z = DiagonalGaussianDistribution(z)
            codes, latents, commitment_loss, codebook_loss = None, None, 0, 0

        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode given latent codes and return audio data.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized (or, in continuous mode, sampled) latent representation

        Returns
        -------
        Tensor[B x 1 x length]
            Decoded audio data.
        """
        if not self.continuous:
            audio = self.decoder(z)
        else:
            # Continuous mode: project back from the VAE latent first.
            z = self.post_quant_conv(z)
            audio = self.decoder(z)

        return audio

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass: encode, (sample,) decode, trim to input length.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            In discrete mode: "audio" (Tensor[B x 1 x length], decoded
            audio), "z" (quantized representation), "codes", "latents",
            "vq/commitment_loss" and "vq/codebook_loss".
            In continuous mode: "audio", "z" (sampled latent) and
            "kl_loss" (mean KL of the posterior).
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)

            x = self.decode(z)
            return {
                "audio": x[..., :length],
                "z": z,
                "codes": codes,
                "latents": latents,
                "vq/commitment_loss": commitment_loss,
                "vq/codebook_loss": codebook_loss,
            }
        else:
            # VAE path: sample from the posterior, decode, and report KL.
            posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
            z = posterior.sample()
            x = self.decode(z)

            kl_loss = posterior.kl()
            kl_loss = kl_loss.mean()

            return {
                "audio": x[..., :length],
                "z": z,
                "kl_loss": kl_loss,
            }
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
# Smoke test: print per-layer parameter counts, measure the receptive field
# via backprop, and round-trip a signal through compress/decompress.
if __name__ == "__main__":
    import numpy as np
    from functools import partial

    model = DAC().to("cpu")

    # Patch every module's extra_repr so print(model) also shows its
    # parameter count in millions.
    for n, m in model.named_modules():
        o = m.extra_repr()
        p = sum([np.prod(p.size()) for p in m.parameters()])
        fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
        setattr(m, "extra_repr", partial(fn, o=o, p=p))
    print(model)
    print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))

    length = 88200 * 2
    x = torch.randn(1, 1, length).to(model.device)
    x.requires_grad_(True)
    x.retain_grad()

    # Make a forward pass
    out = model(x)["audio"]
    print("Input shape:", x.shape)
    print("Output shape:", out.shape)

    # Create gradient variable: a one-hot at the center output sample, so
    # backprop marks exactly the input samples that influence it.
    grad = torch.zeros_like(out)
    grad[:, :, grad.shape[-1] // 2] = 1

    # Make a backward pass
    out.backward(grad)

    # Check non-zero values: count input positions with nonzero gradient,
    # which is the model's receptive field for one output sample.
    gradmap = x.grad.squeeze(0)
    gradmap = (gradmap != 0).sum(0)  # sum across features
    rf = (gradmap != 0).sum()

    print(f"Receptive field: {rf.item()}")

    # Round-trip a one-minute random signal through compress/decompress.
    x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
    model.decompress(model.compress(x, verbose=True), verbose=True)
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/discriminator.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from audiotools import AudioSignal
|
| 5 |
+
from audiotools import ml
|
| 6 |
+
from audiotools import STFTParams
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from torch.nn.utils import weight_norm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def WNConv1d(*args, **kwargs):
    """Weight-normalized Conv1d, followed by LeakyReLU(0.1) unless ``act=False``."""
    use_activation = kwargs.pop("act", True)
    conv = weight_norm(nn.Conv1d(*args, **kwargs))
    if use_activation:
        return nn.Sequential(conv, nn.LeakyReLU(0.1))
    return conv
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def WNConv2d(*args, **kwargs):
    """Weight-normalized Conv2d, followed by LeakyReLU(0.1) unless ``act=False``."""
    use_activation = kwargs.pop("act", True)
    conv = weight_norm(nn.Conv2d(*args, **kwargs))
    if use_activation:
        return nn.Sequential(conv, nn.LeakyReLU(0.1))
    return conv
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class MPD(nn.Module):
    """Multi-period discriminator: folds the waveform into rows of length
    ``period`` and runs a stack of 2D convolutions over the result."""

    def __init__(self, period):
        super().__init__()
        self.period = period
        channel_plan = [(1, 32), (32, 128), (128, 512), (512, 1024)]
        layers = [
            WNConv2d(cin, cout, (5, 1), (3, 1), padding=(2, 0))
            for cin, cout in channel_plan
        ]
        layers.append(WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)))
        self.convs = nn.ModuleList(layers)
        # Final 1-channel logit map, no activation.
        self.conv_post = WNConv2d(
            1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
        )

    def pad_to_period(self, x):
        """Reflect-pad so the time axis is divisible by ``self.period``."""
        n_samples = x.shape[-1]
        pad_amount = self.period - n_samples % self.period
        return F.pad(x, (0, pad_amount), mode="reflect")

    def forward(self, x):
        feature_maps = []

        x = self.pad_to_period(x)
        # (batch, chan, time) -> (batch, chan, time//period, period)
        x = rearrange(x, "b c (l p) -> b c l p", p=self.period)

        for conv in self.convs:
            x = conv(x)
            feature_maps.append(x)

        feature_maps.append(self.conv_post(x))

        return feature_maps
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class MSD(nn.Module):
    """Multi-scale discriminator: resamples the input by ``rate`` and runs a
    stack of grouped 1D convolutions over the waveform."""

    def __init__(self, rate: int = 1, sample_rate: int = 44100):
        super().__init__()
        # (in_ch, out_ch, kernel, stride, groups, padding) per layer.
        layer_specs = [
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList(
            [
                WNConv1d(cin, cout, k, s, groups=g, padding=p)
                for cin, cout, k, s, g, p in layer_specs
            ]
        )
        # Final 1-channel logit map, no activation.
        self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
        self.sample_rate = sample_rate
        self.rate = rate

    def forward(self, x):
        signal = AudioSignal(x, self.sample_rate)
        signal.resample(self.sample_rate // self.rate)
        x = signal.audio_data

        feature_maps = []

        for conv in self.convs:
            x = conv(x)
            feature_maps.append(x)
        feature_maps.append(self.conv_post(x))

        return feature_maps
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Frequency bands, expressed as fractions of the FFT bin count, that MRD
# splits each spectrogram into before running its per-band conv stacks.
BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class MRD(nn.Module):
    def __init__(
        self,
        window_length: int,
        hop_factor: float = 0.25,
        sample_rate: int = 44100,
        bands: list = BANDS,
    ):
        """Complex multi-band spectrogram discriminator.

        Runs a separate 2D conv stack over each frequency band of the complex
        STFT (real/imag as channels), then concatenates the bands for a final
        logit convolution.

        Parameters
        ----------
        window_length : int
            Window length of STFT.
        hop_factor : float, optional
            Hop factor of the STFT; hop length is ``window_length * hop_factor``,
            by default 0.25.
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 44100
        bands : list, optional
            Bands to run discriminator over, as fractions of the bin count.
        """
        super().__init__()

        self.window_length = window_length
        self.hop_factor = hop_factor
        self.sample_rate = sample_rate
        self.stft_params = STFTParams(
            window_length=window_length,
            hop_length=int(window_length * hop_factor),
            match_stride=True,
        )

        # Convert fractional band edges into absolute FFT-bin indices.
        n_fft = window_length // 2 + 1
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands

        ch = 32
        # One independent conv stack per band (factory so weights aren't shared).
        convs = lambda: nn.ModuleList(
            [
                WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
        self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)

    def spectrogram(self, x):
        """Compute the complex STFT (real/imag as 2 channels) and split it
        into ``self.bands`` along the frequency axis."""
        x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
        x = torch.view_as_real(x.stft())
        # (batch, 1, freq, time, 2) -> (batch, 2, time, freq)
        x = rearrange(x, "b 1 f t c -> (b 1) c t f")
        # Split into bands
        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
        return x_bands

    def forward(self, x):
        x_bands = self.spectrogram(x)
        fmap = []

        x = []
        for band, stack in zip(x_bands, self.band_convs):
            for layer in stack:
                band = layer(band)
                fmap.append(band)
            x.append(band)

        # Re-join the processed bands along frequency before the final logit conv.
        x = torch.cat(x, dim=-1)
        x = self.conv_post(x)
        fmap.append(x)

        return fmap
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class Discriminator(ml.BaseModel):
    def __init__(
        self,
        rates: list = [],
        periods: list = [2, 3, 5, 7, 11],
        fft_sizes: list = [2048, 1024, 512],
        sample_rate: int = 44100,
        bands: list = BANDS,
    ):
        """Discriminator that combines multiple discriminators.

        Parameters
        ----------
        rates : list, optional
            sampling rates (in Hz) to run MSD at, by default []
            If empty, MSD is not used.
        periods : list, optional
            periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
        fft_sizes : list, optional
            Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 44100
        bands : list, optional
            Bands to run MRD at, by default `BANDS`
        """
        super().__init__()
        discs = []
        discs += [MPD(p) for p in periods]
        discs += [MSD(r, sample_rate=sample_rate) for r in rates]
        discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
        self.discriminators = nn.ModuleList(discs)

    def preprocess(self, y):
        """Normalize a waveform before discrimination: remove DC offset and
        peak-normalize to 0.8 so all sub-discriminators see consistent levels.
        """
        # Remove DC offset (use the canonical `keepdim` kwarg, consistent with
        # the normalization line below).
        y = y - y.mean(dim=-1, keepdim=True)
        # Peak normalize the volume of input audio; 1e-9 guards against
        # division by zero on silent input.
        y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        return y

    def forward(self, x):
        """Run every sub-discriminator on the preprocessed input.

        Returns a list (one entry per sub-discriminator) of feature-map
        lists; the last tensor in each inner list is that discriminator's
        logit map.
        """
        x = self.preprocess(x)
        fmaps = [d(x) for d in self.discriminators]
        return fmaps
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
if __name__ == "__main__":
    # Quick smoke test: run a second of silence through every sub-discriminator
    # and print the shape/statistics of each returned feature map.
    disc = Discriminator()
    x = torch.zeros(1, 1, 44100)
    results = disc(x)
    for i, result in enumerate(results):
        print(f"disc{i}")
        # The inner loop previously re-used `i` via an unused enumerate,
        # shadowing the outer index; iterate the maps directly instead.
        for r in result:
            print(r.shape, r.mean(), r.min(), r.max())
        print()
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import layers
|
| 2 |
+
from . import loss
|
| 3 |
+
from . import quantize
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/layers.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from torch.nn.utils import weight_norm
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def WNConv1d(*args, **kwargs):
    """Conv1d wrapped with weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def WNConvTranspose1d(*args, **kwargs):
    """ConvTranspose1d wrapped with weight normalization."""
    conv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(conv)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    """Snake activation: x + (1/alpha) * sin^2(alpha * x).

    ``alpha`` controls the frequency of the periodic component and is
    expected to broadcast against a (B, C, T) view of ``x``.
    """
    shape = x.shape
    # Flatten trailing dims to a 3D view so a (1, C, 1) alpha broadcasts.
    x = x.reshape(shape[0], shape[1], -1)
    # 1e-9 guards the reciprocal against division by zero when alpha ~ 0.
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Snake1d(nn.Module):
    """Snake activation with a learnable per-channel ``alpha`` parameter."""

    def __init__(self, channels):
        super().__init__()
        # One alpha per channel, shaped to broadcast over batch and time.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/loss.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typing
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from audiotools import AudioSignal
|
| 7 |
+
from audiotools import STFTParams
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class L1Loss(nn.L1Loss):
    """L1 loss between AudioSignals (or raw tensors).

    Compares ``audio_data`` by default, but any attribute of an
    AudioSignal can be selected.

    Parameters
    ----------
    attribute : str, optional
        Attribute of signal to compare, defaults to ``audio_data``.
    weight : float, optional
        Weight of this loss, defaults to 1.0.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        self.attribute = attribute
        self.weight = weight
        super().__init__(**kwargs)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Return the L1 distance between the chosen attribute of the
        estimate ``x`` and the reference ``y``; raw tensors pass through
        to ``nn.L1Loss`` unchanged.
        """
        if isinstance(x, AudioSignal):
            x, y = getattr(x, self.attribute), getattr(y, self.attribute)
        return super().forward(x, y)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    Parameters
    ----------
    scaling : int, optional
        Whether to use scale-invariant (True) or
        signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch (either 'mean',
        'sum', or none), by default 'mean'
    zero_mean : int, optional
        Zero mean the references and estimates before
        computing the loss, by default True
    clip_min : int, optional
        The minimum possible loss value. Helps network
        to not focus on making already good examples better, by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(
        self,
        scaling: int = True,
        reduction: str = "mean",
        zero_mean: int = True,
        clip_min: int = None,
        weight: float = 1.0,
    ):
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight
        super().__init__()

    def forward(self, x: AudioSignal, y: AudioSignal):
        # NOTE(review): here ``x`` is treated as the reference and ``y`` as
        # the estimate — the opposite argument order of the other losses in
        # this file. Confirm callers pass arguments in this order.
        eps = 1e-8
        # nb, nc, nt
        if isinstance(x, AudioSignal):
            references = x.audio_data
            estimates = y.audio_data
        else:
            references = x
            estimates = y

        # Flatten channels into time and move samples to axis 1: (nb, nt, 1).
        nb = references.shape[0]
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)

        # samples now on axis 1
        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate

        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps

        # Optimal scaling factor projecting the estimate onto the reference —
        # this is what makes the ratio scale-invariant.
        scale = (
            (references_on_estimates / references_projection).unsqueeze(1)
            if self.scaling
            else 1
        )

        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        # Negated so that a better SDR yields a lower loss.
        sdr = -10 * torch.log10(signal / noise + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class MultiScaleSTFTLoss(nn.Module):
    """Computes the multi-scale STFT loss from [1].

    Parameters
    ----------
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False

    References
    ----------

    1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
       "DDSP: Differentiable Digital Signal Processing."
       International Conference on Learning Representations. 2019.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: str = None,
    ):
        super().__init__()
        # One STFT configuration per scale; hop is fixed at a quarter window.
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.loss_fn = loss_fn
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.clamp_eps = clamp_eps
        self.weight = weight
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes multi-scale STFT between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Multi-scale STFT loss.
        """
        loss = 0.0
        for s in self.stft_params:
            # stft() computes the transform and stores it on the signal; the
            # .magnitude reads below use that stored result.
            x.stft(s.window_length, s.hop_length, s.window_type)
            y.stft(s.window_length, s.hop_length, s.window_type)
            # Log-magnitude term (clamped to avoid log10(0))...
            loss += self.log_weight * self.loss_fn(
                x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
                y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            # ...plus a raw-magnitude term.
            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
        return loss
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class MelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80],
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[float] = [None, None],
        window_type: str = None,
    ):
        super().__init__()
        # One STFT configuration per scale; n_mels/mel_fmin/mel_fmax are
        # parallel lists zipped against these params in forward().
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.n_mels = n_mels
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss.
        """
        loss = 0.0
        for n_mels, fmin, fmax, s in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            kwargs = {
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "window_type": s.window_type,
            }
            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)

            # Log-magnitude term (clamped to avoid log10(0))...
            loss += self.log_weight * self.loss_fn(
                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            # ...plus a raw-magnitude term.
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
        return loss
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class GANLoss(nn.Module):
    """
    Computes a discriminator loss, given a discriminator on
    generated waveforms/spectrograms compared to ground truth
    waveforms/spectrograms. Computes the loss for both the
    discriminator and the generator in separate functions.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, fake, real):
        """Run the discriminator on both signals. Each result is a list of
        feature-map lists (one per sub-discriminator); the last entry of
        each inner list is the logit map."""
        d_fake = self.discriminator(fake.audio_data)
        d_real = self.discriminator(real.audio_data)
        return d_fake, d_real

    def discriminator_loss(self, fake, real):
        """Least-squares GAN loss for the discriminator: push fake logits
        toward 0 and real logits toward 1. ``fake`` is detached so no
        gradient flows back into the generator."""
        d_fake, d_real = self.forward(fake.clone().detach(), real)

        loss_d = 0
        for x_fake, x_real in zip(d_fake, d_real):
            loss_d += torch.mean(x_fake[-1] ** 2)
            loss_d += torch.mean((1 - x_real[-1]) ** 2)
        return loss_d

    def generator_loss(self, fake, real):
        """Least-squares GAN loss for the generator (push fake logits toward
        1) plus an L1 feature-matching loss over all intermediate feature
        maps (real features detached so only the generator is updated)."""
        d_fake, d_real = self.forward(fake, real)

        loss_g = 0
        for x_fake in d_fake:
            loss_g += torch.mean((1 - x_fake[-1]) ** 2)

        loss_feature = 0

        # Skip the last entry of each list — that is the logit map, not a
        # feature map.
        for i in range(len(d_fake)):
            for j in range(len(d_fake[i]) - 1):
                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
        return loss_g, loss_feature
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/quantize.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from torch.nn.utils import weight_norm
|
| 9 |
+
|
| 10 |
+
from .layers import WNConv1d
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class VectorQuantize(nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
Implementation of VQ similar to Karpathy's repo:
|
| 16 |
+
https://github.com/karpathy/deep-vector-quantization
|
| 17 |
+
Additionally uses following tricks from Improved VQGAN
|
| 18 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
| 19 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
| 20 |
+
for improved codebook usage
|
| 21 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
| 22 |
+
improves training stability
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Factorized codes (Improved VQGAN): project down to codebook_dim for
        # the nearest-neighbor lookup, then back up to input_dim afterwards.
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
| 33 |
+
|
| 34 |
+
def forward(self, z):
|
| 35 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
| 36 |
+
the corresponding codebook vectors
|
| 37 |
+
|
| 38 |
+
Parameters
|
| 39 |
+
----------
|
| 40 |
+
z : Tensor[B x D x T]
|
| 41 |
+
|
| 42 |
+
Returns
|
| 43 |
+
-------
|
| 44 |
+
Tensor[B x D x T]
|
| 45 |
+
Quantized continuous representation of input
|
| 46 |
+
Tensor[1]
|
| 47 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
| 48 |
+
entries
|
| 49 |
+
Tensor[1]
|
| 50 |
+
Codebook loss to update the codebook
|
| 51 |
+
Tensor[B x T]
|
| 52 |
+
Codebook indices (quantized discrete representation of input)
|
| 53 |
+
Tensor[B x D x T]
|
| 54 |
+
Projected latents (continuous representation of input before quantization)
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
| 58 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
| 59 |
+
z_q, indices = self.decode_latents(z_e)
|
| 60 |
+
|
| 61 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
| 62 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
| 63 |
+
|
| 64 |
+
z_q = (
|
| 65 |
+
z_e + (z_q - z_e).detach()
|
| 66 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
| 67 |
+
|
| 68 |
+
z_q = self.out_proj(z_q)
|
| 69 |
+
|
| 70 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
| 71 |
+
|
| 72 |
+
def embed_code(self, embed_id):
|
| 73 |
+
return F.embedding(embed_id, self.codebook.weight)
|
| 74 |
+
|
| 75 |
+
def decode_code(self, embed_id):
|
| 76 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
| 77 |
+
|
| 78 |
+
def decode_latents(self, latents):
|
| 79 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
| 80 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
| 81 |
+
|
| 82 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
| 83 |
+
encodings = F.normalize(encodings)
|
| 84 |
+
codebook = F.normalize(codebook)
|
| 85 |
+
|
| 86 |
+
# Compute euclidean distance with codebook
|
| 87 |
+
dist = (
|
| 88 |
+
encodings.pow(2).sum(1, keepdim=True)
|
| 89 |
+
- 2 * encodings @ codebook.t()
|
| 90 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
| 91 |
+
)
|
| 92 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
| 93 |
+
z_q = self.decode_code(indices)
|
| 94 |
+
return z_q, indices
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312

    A cascade of `n_codebooks` VectorQuantize stages, each quantizing the
    residual left by the previous stage.
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        # A single int means all codebooks share the same code dimension.
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        # Fraction of each batch that trains with a random number of
        # quantizers (bitrate scalability, SoundStream sec. 3.2).
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantized the input tensor using a fixed set of `n` codebooks and returns
        the corresponding codebook vectors
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: in training mode this argument is ignored and a (possibly
            random, see `quantizer_dropout`) number of quantizers is used.
        Returns
        -------
        tuple
            A 5-tuple of:

            z : Tensor[B x D x T]
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            latents : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            vq/commitment_loss : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            vq/codebook_loss : Tensor[1]
                Codebook loss to update the codebook
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            # Quantizer dropout: the first `n_dropout` samples in the batch
            # get a random number of active quantizers; the rest use all of
            # them ("+ 1" keeps their threshold above every codebook index).
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            # At inference time we can simply stop early; in training all
            # stages run and masking handles the per-sample cutoff.
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            # Each stage quantizes what the previous stages left over.
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation
        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input
        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            # Residual stages sum in the full-dimensional space.
            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        """
        z_q = 0
        z_p = []
        codes = []
        # Channel offsets of each codebook's slice within the concatenated
        # latents; how many whole codebooks fit determines how many we use.
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
            0
        ]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
if __name__ == "__main__":
    # Smoke test: push a random batch through the RVQ stack with quantizer
    # dropout enabled and print the shape of the concatenated latents.
    rvq = ResidualVectorQuantize(quantizer_dropout=True)
    x = torch.randn(16, 512, 80)
    # forward() returns a 5-tuple, not a dict; the previous
    # `y["latents"]` raised TypeError (tuple indices must be integers).
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
    print(latents.shape)
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class AbstractDistribution:
    """Minimal interface for distributions: draw a sample or return the mode.

    Subclasses must override both methods.
    """

    def sample(self):
        """Draw a single sample from the distribution."""
        raise NotImplementedError()

    def mode(self):
        """Return the most likely value of the distribution."""
        raise NotImplementedError()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DiracDistribution(AbstractDistribution):
    """Point-mass distribution: sampling and the mode both yield the stored value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        """A Dirac sample is deterministic — always the fixed value."""
        return self.value

    def mode(self):
        """The mode of a point mass is the value itself."""
        return self.value
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DiagonalGaussianDistribution(object):
    """Gaussian with diagonal covariance parameterized by concatenated mean/logvar.

    `parameters` is split in half along dim=1: the first half is the mean, the
    second half the log-variance (clamped to [-30, 20] for numerical
    stability). With `deterministic=True` the variance is forced to zero, so
    `sample()` returns the mean and `kl()`/`nll()` return zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp keeps exp() finite in both directions.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches device/dtype; the old extra
            # `.to(device=...)` was redundant.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw one reparameterized sample: mean + std * eps."""
        # Allocate the noise directly on the parameters' device instead of
        # creating it on CPU and copying.
        noise = torch.randn(self.mean.shape, device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """KL divergence to a standard normal (other=None) or to `other`.

        Reduced with a mean over dims [1, 2] — assumes 3-D tensors
        (batch, channel, time); TODO confirm against callers.
        """
        if self.deterministic:
            return torch.tensor([0.0])
        if other is None:
            return 0.5 * torch.mean(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2],
            )
        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2],
        )

    def nll(self, sample, dims=(1, 2)):
        """Negative log-likelihood of `sample`, summed over `dims`.

        The default is a tuple rather than the previous mutable list
        (mutable default arguments are a Python anti-pattern).
        """
        if self.deterministic:
            return torch.tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        """The most likely value of a Gaussian is its mean."""
        return self.mean
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # Locate any tensor argument to borrow its dtype/device for conversions.
    tensor = next(
        (obj for obj in (mean1, logvar1, mean2, logvar2) if isinstance(obj, torch.Tensor)),
        None,
    )
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = (
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    )

    kl = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
    kl = kl + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    return 0.5 * kl
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/__init__.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import argbind
|
| 4 |
+
from audiotools import ml
|
| 5 |
+
|
| 6 |
+
from ..model import DAC
|
| 7 |
+
# Re-export audiotools' Accelerator so callers can import it from this package.
Accelerator = ml.Accelerator

# Latest released weight tag for each (sample-rate, bitrate) pair.
__MODEL_LATEST_TAGS__ = {
    ("44khz", "8kbps"): "0.0.1",
    ("24khz", "8kbps"): "0.0.4",
    ("16khz", "8kbps"): "0.0.5",
    ("44khz", "16kbps"): "1.0.0",
}

# Download URL for each (sample-rate, tag, bitrate) triple, keyed to the
# upstream descript-audio-codec GitHub releases.
__MODEL_URLS__ = {
    (
        "44khz",
        "0.0.1",
        "8kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
    (
        "24khz",
        "0.0.4",
        "8kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
    (
        "16khz",
        "0.0.5",
        "8kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
    (
        "44khz",
        "1.0.0",
        "16kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@argbind.bind(group="download", positional=True, without_prefix=True)
def download(
    model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
):
    """
    Function that downloads the weights file from URL if a local cache is not found.

    Parameters
    ----------
    model_type : str
        The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
    model_bitrate: str
        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
        Only 44khz model supports 16kbps.
    tag : str
        The tag of the model to download. Defaults to "latest".

    Returns
    -------
    Path
        Directory path required to load model via audiotools.
    """
    model_type = model_type.lower()
    tag = tag.lower()

    assert model_type in [
        "44khz",
        "24khz",
        "16khz",
    ], "model_type must be one of '44khz', '24khz', or '16khz'"

    assert model_bitrate in [
        "8kbps",
        "16kbps",
    ], "model_bitrate must be one of '8kbps', or '16kbps'"

    if tag == "latest":
        tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]

    download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)

    if download_link is None:
        raise ValueError(
            f"Could not find model with tag {tag} and model type {model_type}"
        )

    # Cache weights under ~/.cache/descript/dac, keyed by type/bitrate/tag.
    local_path = (
        Path.home()
        / ".cache"
        / "descript"
        / "dac"
        / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
    )
    if not local_path.exists():
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Download the model
        import requests

        # requests has NO default timeout; without one a stalled connection
        # hangs forever. 600s is generous for the weight files involved.
        response = requests.get(download_link, timeout=600)

        if response.status_code != 200:
            raise ValueError(
                f"Could not download model. Received response code {response.status_code}"
            )
        local_path.write_bytes(response.content)

    return local_path
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def load_model(
    model_type: str = "44khz",
    model_bitrate: str = "8kbps",
    tag: str = "latest",
    load_path: str = None,
):
    """Load a DAC generator, downloading cached weights when no path is given.

    If `load_path` is falsy, weights are fetched (or read from the local
    cache) via `download()` using the given type/bitrate/tag.
    """
    if not load_path:
        load_path = download(
            model_type=model_type, model_bitrate=model_bitrate, tag=tag
        )
    return DAC.load(load_path)
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/decode.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import argbind
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from audiotools import AudioSignal
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from ..model import DACFile
|
| 11 |
+
from . import load_model
|
| 12 |
+
|
| 13 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@argbind.bind(group="decode", positional=True, without_prefix=True)
@torch.inference_mode()
@torch.no_grad()
def decode(
    input: str,
    output: str = "",
    weights_path: str = "",
    model_tag: str = "latest",
    model_bitrate: str = "8kbps",
    device: str = "cuda",
    model_type: str = "44khz",
    verbose: bool = False,
):
    """Decode audio from codes.

    Parameters
    ----------
    input : str
        Path to input directory or file
    output : str, optional
        Path to output directory, by default "".
        If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
    weights_path : str, optional
        Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
        model_tag and model_type.
    model_tag : str, optional
        Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
    model_bitrate: str
        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
    device : str, optional
        Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
    model_type : str, optional
        The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
    verbose : bool, optional
        Forwarded to `generator.decompress`, by default False.
    """
    generator = load_model(
        model_type=model_type,
        model_bitrate=model_bitrate,
        tag=model_tag,
        load_path=weights_path,
    )
    generator.to(device)
    generator.eval()

    # Find all .dac files in input directory
    _input = Path(input)
    input_files = list(_input.glob("**/*.dac"))

    # If input is a .dac file, add it to the list
    if _input.suffix == ".dac":
        input_files.append(_input)

    # Create output directory
    output = Path(output)
    output.mkdir(parents=True, exist_ok=True)

    for i in tqdm(range(len(input_files)), desc=f"Decoding files"):
        # Load file
        artifact = DACFile.load(input_files[i])

        # Reconstruct audio from codes
        recons = generator.decompress(artifact, verbose=verbose)

        # Compute output path, mirroring the input sub-tree under `output`.
        relative_path = input_files[i].relative_to(input)
        output_dir = output / relative_path.parent
        if not relative_path.name:
            # Fallback when relative_to yields an empty name — write directly
            # under `output` using the original file name.
            output_dir = output
            relative_path = input_files[i]
        output_name = relative_path.with_suffix(".wav").name
        output_path = output_dir / output_name
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Write to file
        recons.write(output_path)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
    # Parse CLI flags via argbind and run decode() inside the bound scope.
    args = argbind.parse_args()
    with argbind.scope(args):
        decode()
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/encode.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import warnings
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import argbind
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from audiotools import AudioSignal
|
| 9 |
+
from audiotools.core import util
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from . import load_model
|
| 13 |
+
|
| 14 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@argbind.bind(group="encode", positional=True, without_prefix=True)
@torch.inference_mode()
@torch.no_grad()
def encode(
    input: str,
    output: str = "",
    weights_path: str = "",
    model_tag: str = "latest",
    model_bitrate: str = "8kbps",
    n_quantizers: int = None,
    device: str = "cuda",
    model_type: str = "44khz",
    win_duration: float = 5.0,
    verbose: bool = False,
):
    """Encode audio files in input path to .dac format.

    Parameters
    ----------
    input : str
        Path to input audio file or directory
    output : str, optional
        Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
    weights_path : str, optional
        Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
        model_tag and model_type.
    model_tag : str, optional
        Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
    model_bitrate: str
        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
    n_quantizers : int, optional
        Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
    device : str, optional
        Device to use, by default "cuda"
    model_type : str, optional
        The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
    win_duration : float, optional
        Window duration in seconds passed to `generator.compress`, by default 5.0.
    verbose : bool, optional
        Forwarded to `generator.compress`, by default False.
    """
    generator = load_model(
        model_type=model_type,
        model_bitrate=model_bitrate,
        tag=model_tag,
        load_path=weights_path,
    )
    generator.to(device)
    generator.eval()
    kwargs = {"n_quantizers": n_quantizers}

    # Find all audio files in input path
    input = Path(input)
    audio_files = util.find_audio(input)

    output = Path(output)
    output.mkdir(parents=True, exist_ok=True)

    for i in tqdm(range(len(audio_files)), desc="Encoding files"):
        # Load file
        signal = AudioSignal(audio_files[i])

        # Encode audio to .dac format
        artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)

        # Compute output path, mirroring the input sub-tree under `output`.
        relative_path = audio_files[i].relative_to(input)
        output_dir = output / relative_path.parent
        if not relative_path.name:
            # Fallback when relative_to yields an empty name — write directly
            # under `output` using the original file name.
            output_dir = output
            relative_path = audio_files[i]
        output_name = relative_path.with_suffix(".dac").name
        output_path = output_dir / output_name
        output_path.parent.mkdir(parents=True, exist_ok=True)

        artifact.save(output_path)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
    # Parse CLI flags via argbind and run encode() inside the bound scope.
    args = argbind.parse_args()
    with argbind.scope(args):
        encode()
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/hifi_foley.py
ADDED
|
@@ -0,0 +1,794 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple, Optional, Union, Dict
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from einops.layers.torch import Rearrange
|
| 8 |
+
from diffusers.models import ModelMixin
|
| 9 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 10 |
+
|
| 11 |
+
from .nn.activation_layers import SwiGLU, get_activation_layer
|
| 12 |
+
from .nn.attn_layers import apply_rotary_emb, attention
|
| 13 |
+
from .nn.embed_layers import TimestepEmbedder, ConditionProjection, PatchEmbed1D
|
| 14 |
+
from .nn.mlp_layers import MLP, ConvMLP, FinalLayer1D, ChannelLastConv1d
|
| 15 |
+
from .nn.modulate_layers import ModulateDiT, ckpt_wrapper, apply_gate, modulate
|
| 16 |
+
from .nn.norm_layers import get_norm_layer
|
| 17 |
+
from .nn.posemb_layers import get_nd_rotary_pos_embed
|
| 18 |
+
|
| 19 |
+
def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor):
    """Interleave two [B, N, H, C] token sequences along the token axis.

    The result alternates tokens as [x1_0, x2_0, x1_1, x2_1, ...] and has shape
    [B, 2 * N1, H, C]. When the two sequences have different lengths, ``x2`` is
    first resampled along the token axis (nearest-exact) to match ``x1``.
    """
    bsz, len1, heads, ch = x1.shape
    _, len2, _, _ = x2.shape
    assert x1.ndim == x2.ndim == 4

    if len2 != len1:
        # Resample x2 over tokens: flatten [H, C] into channels, interpolate
        # along the length dimension, then restore the head/channel split.
        flat = x2.view(bsz, len2, -1).transpose(1, 2)
        flat = F.interpolate(flat, size=len1, mode="nearest-exact")
        x2 = flat.transpose(1, 2).view(bsz, len1, heads, ch)

    # stack on a new axis right after the token axis, then fold it in so the
    # tokens of the two streams alternate.
    interleaved = torch.stack((x1, x2), dim=2)
    return interleaved.reshape(bsz, len1 * 2, heads, ch)
|
| 32 |
+
|
| 33 |
+
def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int):
    """Split an interleaved [B, 2*len1, H, C] sequence back into its two streams.

    Inverse of :func:`interleave_two_sequences`: even tokens form the first
    stream (length ``len1``), odd tokens the second. If the second stream was
    originally a different length, it is resampled (nearest-exact) back to
    ``len2``.
    """
    bsz, total, heads, ch = x.shape
    assert total % 2 == 0 and total // 2 == len1

    # Unfold the alternating layout into an explicit pair axis.
    paired = x.reshape(bsz, -1, 2, heads, ch)
    x1 = paired[:, :, 0]
    x2 = paired[:, :, 1]

    if x2.shape[1] != len2:
        # Resample the second stream back to its native length.
        flat = x2.view(bsz, len1, heads * ch).transpose(1, 2)
        flat = F.interpolate(flat, size=len2, mode="nearest-exact")
        x2 = flat.transpose(1, 2).view(bsz, len2, heads, ch)
    return x1, x2
|
| 45 |
+
|
| 46 |
+
class TwoStreamCABlock(nn.Module):
    """Dual-stream transformer block for joint audio / visual-condition denoising.

    The audio latent stream and the visual-condition stream keep separate
    modulation, normalization, and projection weights, but interact through:

      1. joint self-attention over the concatenated [visual; audio] tokens,
      2. cross-attention in which both streams query the text condition, and
      3. per-stream gated MLPs.

    All shift/scale/gate parameters are produced AdaLN-style from the timestep
    embedding ``vec`` (or, for the audio stream, the per-token ``sync_vec`` when
    one is provided).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        attn_mode: str = "torch",
        reverse: bool = False,
        interleaved_audio_visual_rope: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        """Build the block.

        Args:
            hidden_size: model width; must be divisible by ``num_heads``.
            num_heads: attention head count shared by all attention ops here.
            mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
            mlp_act_type: activation key resolved via ``get_activation_layer``.
            qk_norm: apply per-head Q/K normalization when True.
            qk_norm_type: norm key resolved via ``get_norm_layer``.
            qkv_bias: bias flag used for every linear projection in the block.
            attn_mode: attention backend; only "torch" is supported.
            reverse: stored on the instance; not read anywhere in this block.
            interleaved_audio_visual_rope: if True, RoPE is applied to the
                audio/visual streams after interleaving them token-by-token.
            dtype / device: forwarded to submodules via ``factory_kwargs``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        # NOTE(review): `reverse` is stored but never consulted in this class.
        self.reverse = reverse
        self.attn_mode = attn_mode
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        head_dim = hidden_size // num_heads
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.interleaved_audio_visual_rope = interleaved_audio_visual_rope

        # Self attention for audio + visual.
        # factor=9 -> three (shift, scale, gate) triples: self-attn, cross-attn, MLP.
        self.audio_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
        self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.audio_self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.audio_self_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_self_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        # Visual-condition stream: mirrors the audio stream's layers.
        self.v_cond_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
        self.v_cond_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        self.v_cond_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        # NOTE(review): max_text_len is stored but not read in this block;
        # rope_dim_list=None means "derive from head_dim" in build_rope_for_text.
        self.max_text_len = 100
        self.rope_dim_list = None

        # Audio and video norms for cross attention with text.
        self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        # Cross attention: (video, audio) as query, text as key/value.
        self.audio_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.v_cond_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.text_cross_kv = nn.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs)

        self.audio_cross_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_cross_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.text_cross_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.v_cond_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        # Per-stream MLPs.
        self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.audio_mlp = MLP(
            hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
        )

        self.v_cond_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_mlp = MLP(
            hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
        )

    def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
        """Build 1-D RoPE (cos, sin) tables for a sequence of ``text_len`` tokens.

        When ``rope_dim_list`` is None the whole ``head_dim`` is assigned to the
        single positional axis.
        """
        target_ndim = 1  # n-d RoPE
        rope_sizes = [text_len]

        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

        text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=rope_sizes,
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
        )
        return text_freqs_cos, text_freqs_sin

    def set_attn_mode(self, new_mode):
        """Select the attention backend; only "torch" is implemented."""
        if new_mode != "torch":
            raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.")
        self.attn_mode = new_mode

    def enable_deterministic(self):
        """Request deterministic attention (forwarded to the `attention` helper)."""
        self.deterministic = True

    def disable_deterministic(self):
        """Revert to the default (possibly non-deterministic) attention path."""
        self.deterministic = False

    def forward(
        self,
        audio: torch.Tensor,
        cond: torch.Tensor,
        v_cond: torch.Tensor,
        attn_mask: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: tuple = None,
        v_freqs_cis: tuple = None,
        sync_vec: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one dual-stream block.

        Args:
            audio: audio latent tokens [B, La, D].
            cond: text condition tokens [B, Lt, D] (returned unchanged).
            v_cond: visual condition tokens [B, Lv, D].
            attn_mask: mask for the joint [visual; audio] self-attention.
            vec: global (timestep) modulation embedding.
            freqs_cis: RoPE (cos, sin) for the audio tokens — or for the
                interleaved audio/visual sequence when
                ``interleaved_audio_visual_rope`` is set.
            v_freqs_cis: RoPE (cos, sin) for the visual tokens (non-interleaved
                mode only).
            sync_vec: optional per-token [B, La, D] modulation that replaces
                ``vec`` for the audio stream.

        Returns:
            (audio, cond, v_cond) with audio and v_cond updated in residual
            fashion; cond passes through untouched.
        """
        # Get modulation parameters (9 chunks = 3 stages x shift/scale/gate).
        if sync_vec is not None:
            assert sync_vec.ndim == 3
            (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
             audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
             audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
             ) = self.audio_mod(sync_vec).chunk(9, dim=-1)
        else:
            (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
             audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
             audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
             ) = self.audio_mod(vec).chunk(9, dim=-1)

        (
            v_cond_mod1_shift,
            v_cond_mod1_scale,
            v_cond_mod1_gate,
            v_cond_mod2_shift,
            v_cond_mod2_scale,
            v_cond_mod2_gate,
            v_cond_mod3_shift,
            v_cond_mod3_scale,
            v_cond_mod3_gate,
        ) = self.v_cond_mod(vec).chunk(9, dim=-1)

        # 1. Self Attention for audio + visual
        audio_modulated = self.audio_norm1(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod1_shift, scale=audio_mod1_scale)
        audio_qkv = self.audio_self_attn_qkv(audio_modulated)
        audio_q, audio_k, audio_v = rearrange(audio_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        # Q/K norm outputs are cast back to V's dtype (norm may upcast).
        audio_q = self.audio_self_q_norm(audio_q).to(audio_v)
        audio_k = self.audio_self_k_norm(audio_k).to(audio_v)

        # Prepare visual cond for attention
        v_cond_modulated = self.v_cond_norm1(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod1_shift, scale=v_cond_mod1_scale)
        v_cond_qkv = self.v_cond_attn_qkv(v_cond_modulated)
        v_cond_q, v_cond_k, v_cond_v = rearrange(v_cond_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        v_cond_q = self.v_cond_attn_q_norm(v_cond_q).to(v_cond_v)
        v_cond_k = self.v_cond_attn_k_norm(v_cond_k).to(v_cond_v)

        # Apply RoPE if needed for audio and visual
        if freqs_cis is not None:
            if not self.interleaved_audio_visual_rope:
                audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False)
                audio_q, audio_k = audio_qq, audio_kk
            else:
                # Interleave audio and visual tokens so one shared positional
                # table covers both streams, then split them back apart.
                ori_audio_len = audio_q.shape[1]
                ori_v_con_len = v_cond_q.shape[1]
                interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q)
                interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k)
                interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb(
                    interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False
                )
                audio_qq, v_cond_qq = decouple_interleaved_two_sequences(
                    interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len
                )
                audio_kk, v_cond_kk = decouple_interleaved_two_sequences(
                    interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len
                )
                audio_q, audio_k = audio_qq, audio_kk
                v_cond_q, v_cond_k = v_cond_qq, v_cond_kk

        # Apply RoPE to visual if needed and not interleaved
        if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
            v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)
            v_cond_q, v_cond_k = v_cond_qq, v_cond_kk

        # Concatenate for self-attention: visual tokens first, audio after.
        q = torch.cat((v_cond_q, audio_q), dim=1)
        k = torch.cat((v_cond_k, audio_k), dim=1)
        v = torch.cat((v_cond_v, audio_v), dim=1)

        # Run self-attention
        attn = attention(q, k, v, mode=self.attn_mode, attn_mask=attn_mask, deterministic=self.deterministic)
        v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        # Apply self-attention output to audio and v_cond (gated residual).
        audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
        v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)

        # 2. Cross Attention: (v_cond, audio) as query, text as key/value
        # audio, v_cond modulation
        audio_modulated = self.audio_norm2(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod2_shift, scale=audio_mod2_scale)
        v_cond_modulated = self.v_cond_norm2(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod2_shift, scale=v_cond_mod2_scale)

        # Prepare audio query
        audio_q = self.audio_cross_q(audio_modulated)
        audio_q = rearrange(audio_q, "B L (H D) -> B L H D", H=self.num_heads)
        audio_q = self.audio_cross_q_norm(audio_q)

        # Prepare v_cond query
        v_cond_q = self.v_cond_cross_q(v_cond_modulated)
        v_cond_q = rearrange(v_cond_q, "B L (H D) -> B L H D", H=self.num_heads)
        v_cond_q = self.v_cond_cross_q_norm(v_cond_q)

        # Prepare text key/value
        text_kv = self.text_cross_kv(cond)
        text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads)
        text_k = self.text_cross_k_norm(text_k).to(text_v)

        # Apply RoPE to (v_cond, audio) query and text key if needed.
        # NOTE(review): these tables are rebuilt on every forward call and then
        # moved to the tensor's device — a cache keyed by length would avoid
        # the repeated host-side work.
        head_dim = self.hidden_size // self.num_heads
        audio_cross_freqs_cos, audio_cross_freqs_sin = self.build_rope_for_text(audio_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
        audio_cross_freqs_cis = (audio_cross_freqs_cos.to(audio_q.device), audio_cross_freqs_sin.to(audio_q.device))
        # apply_rotary_emb takes a (q, k) pair; index [0] keeps the rotated query only.
        audio_q = apply_rotary_emb(audio_q, audio_q, audio_cross_freqs_cis, head_first=False)[0]

        v_cond_cross_freqs_cos, v_cond_cross_freqs_sin = self.build_rope_for_text(v_cond_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
        v_cond_cross_freqs_cis = (v_cond_cross_freqs_cos.to(v_cond_q.device), v_cond_cross_freqs_sin.to(v_cond_q.device))
        v_cond_q = apply_rotary_emb(v_cond_q, v_cond_q, v_cond_cross_freqs_cis, head_first=False)[0]

        text_len = text_k.shape[1]

        text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim,
                                                                  rope_dim_list=self.rope_dim_list)
        text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device))
        # Index [1] keeps the rotated key only.
        text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1]

        # Concat v_cond and audio for cross-attention
        v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)

        # Run cross-attention
        cross_attn = attention(v_cond_audio_q, text_k, text_v, mode=self.attn_mode, deterministic=self.deterministic)
        v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        # Apply cross-attention output (gated residual).
        audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
        v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)

        # 3. Apply MLPs
        audio = audio + apply_gate(
            self.audio_mlp(modulate(self.audio_norm3(audio), shift=audio_mod3_shift, scale=audio_mod3_scale)),
            gate=audio_mod3_gate,
        )

        # Apply visual MLP
        v_cond = v_cond + apply_gate(
            self.v_cond_mlp(modulate(self.v_cond_norm3(v_cond), shift=v_cond_mod3_shift, scale=v_cond_mod3_scale)),
            gate=v_cond_mod3_gate,
        )

        return audio, cond, v_cond
|
| 318 |
+
|
| 319 |
+
class SingleStreamBlock(nn.Module):
    """Single-stream DiT block applied to the fused audio token sequence.

    Standard AdaLN-modulated pre-norm transformer block: RoPE'd multi-head
    self-attention followed by a convolutional MLP, each behind a gated
    residual connection driven by a per-token condition embedding.
    """

    def __init__(self, hidden_size: int,
                 num_heads: int,
                 mlp_ratio: float,
                 qk_norm_type: str = "rms",
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,):
        """Build the block.

        Args:
            hidden_size: model width; must be divisible by ``num_heads``.
            num_heads: attention head count.
            mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
            qk_norm_type: accepted for signature parity with TwoStreamCABlock.
                NOTE(review): it is currently not consulted — RMSNorm is
                hard-coded for Q/K normalization below.
            dtype / device: forwarded to submodules via ``factory_kwargs``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads

        # factor=6 -> (shift, scale, gate) for the attention and MLP branches.
        self.modulation = ModulateDiT(
            hidden_size=hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        # Fix: pass factory_kwargs so this layer lands on the same
        # device/dtype as every sibling layer in the block.
        self.linear_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
        self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, **factory_kwargs)
        # Fix: mlp_ratio is a float, so the hidden width must be cast to int
        # (matches the int(hidden_size * mlp_ratio) used in TwoStreamCABlock).
        self.linear2 = ConvMLP(hidden_size, int(hidden_size * mlp_ratio), kernel_size=3, padding=1, **factory_kwargs)
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.q_norm = nn.RMSNorm(hidden_size // num_heads)
        self.k_norm = nn.RMSNorm(hidden_size // num_heads)
        # Splits the fused QKV projection into per-head tensors; the trailing
        # K axis (size 3) separates Q, K, and V.
        self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads)

    def forward(self, x: torch.Tensor, cond: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None):
        """Apply the block.

        Args:
            x: token sequence [B, L, D].
            cond: per-token condition embedding [B, T, D] used for AdaLN.
            freqs_cis: RoPE (cos, sin) tables.
                NOTE(review): defaults to None but is used unconditionally —
                callers must always pass it.

        Returns:
            Updated token sequence [B, L, D].
        """
        assert cond.ndim == 3, "Condition should be in shape of [B, T, D]"
        modulation = self.modulation(cond)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
        x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa

        qkv = self.linear_qkv(x_norm1)
        q, k, v = self.rearrange(qkv).chunk(3, dim=-1)
        q = q.squeeze(-1)
        k = k.squeeze(-1)
        v = v.squeeze(-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True)

        # SDPA requires contiguous inputs here after the rearrange/chunk above.
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, 'b h n d -> b n (h d)').contiguous()

        # Gated residual: attention branch, then MLP branch.
        x = x + apply_gate(self.linear1(out), gate=gate_msa)
        x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp)

        return x
|
| 375 |
+
|
| 376 |
+
class HunyuanVideoFoley(ModelMixin, ConfigMixin):
|
| 377 |
+
    @register_to_config
    def __init__(
        self,
        model_config,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        """Build the HunyuanVideo-Foley DiT from a nested project config.

        Args:
            model_config: project config object; the hyper-parameters actually
                read are ``model_config.model_config.model_kwargs`` (a dict-like
                supporting ``.get``).
            dtype: optional dtype forwarded to every submodule.
            device: optional device forwarded to every submodule.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        model_args = model_config.model_config.model_kwargs
        # Number of dual-stream (audio/visual/text) and single-stream blocks.
        self.depth_triple_blocks = model_args.get("depth_triple_blocks", 19)
        self.depth_single_blocks = model_args.get("depth_single_blocks", 38)
        # Gradient checkpoint.
        # NOTE(review): gradient checkpointing is hard-coded off here, so the
        # assert below is currently dead code.
        self.gradient_checkpoint = False
        self.gradient_checkpoint_layers = None
        if self.gradient_checkpoint:
            assert self.gradient_checkpoint_layers <= self.depth_triple_blocks + self.depth_single_blocks, (
                f"Gradient checkpoint layers must be less or equal than the depth of the model. "
                f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and depth={self.depth_triple_blocks + self.depth_single_blocks}."
            )

        self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", False)

        # Condition projection. Default to linear projection.
        self.condition_projection = model_args.get("condition_projection", "linear")
        self.condition_dim = model_args.get("condition_dim", None)
        self.use_attention_mask = model_args.get("use_attention_mask", False)

        self.patch_size = model_args.get("patch_size", 1)
        self.visual_in_channels = model_args.get("clip_dim", 768)
        self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
        self.out_channels = self.audio_vae_latent_dim
        self.unpatchify_channels = self.out_channels
        self.reverse = model_args.get("reverse", False)

        self.num_heads = model_args.get("num_heads", 24)
        self.hidden_size = model_args.get("hidden_size", 3072)
        self.rope_dim_list = model_args.get("rope_dim_list", None)
        self.mlp_ratio = model_args.get("mlp_ratio", 4.0)
        self.mlp_act_type = model_args.get("mlp_act_type", "gelu_tanh")

        self.qkv_bias = model_args.get("qkv_bias", True)
        self.qk_norm = model_args.get("qk_norm", True)
        self.qk_norm_type = model_args.get("qk_norm_type", "rms")
        self.attn_mode = model_args.get("attn_mode", "torch")

        self.embedder_type = model_args.get("embedder_type", "default")

        # sync condition things
        self.sync_modulation = model_args.get("sync_modulation", False)
        self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", False)
        self.sync_feat_dim = model_args.get("sync_feat_dim", 768)
        self.sync_in_ksz = model_args.get("sync_in_ksz", 1)

        # condition tokens length
        self.clip_len = model_args.get("clip_length", 64)
        self.sync_len = model_args.get("sync_length", 192)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}")

        # Build audio patchify layer and visual gated linear projection
        # NOTE(review): this overrides the "patch_size" read from the config a
        # few lines above; the rest of the model assumes patch_size == 1 (see
        # the asserts in build_rope_for_audio_visual).
        self.patch_size = 1
        self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, **factory_kwargs)
        self.visual_proj = SwiGLU(self.visual_in_channels, hidden_dim=self.hidden_size, out_dim=self.hidden_size)

        # condition
        if self.condition_projection == "linear":
            self.cond_in = ConditionProjection(
                self.condition_dim, self.hidden_size, get_activation_layer("silu"), **factory_kwargs
            )
        else:
            raise NotImplementedError(f"Unsupported condition_projection: {self.condition_projection}")

        # time modulation
        self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)

        # visual sync embedder if needed
        if self.sync_in_ksz == 1:
            sync_in_padding = 0
        elif self.sync_in_ksz == 3:
            sync_in_padding = 1
        else:
            raise ValueError
        if self.sync_modulation or self.add_sync_feat_to_audio:
            self.sync_in = nn.Sequential(
                nn.Linear(self.sync_feat_dim, self.hidden_size),
                nn.SiLU(),
                ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding),
            )
            # Learnable positional embedding over 8-frame sync segments.
            self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim)))

        self.triple_blocks = nn.ModuleList(
            [
                TwoStreamCABlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    mlp_act_type=self.mlp_act_type,
                    qk_norm=self.qk_norm,
                    qk_norm_type=self.qk_norm_type,
                    qkv_bias=self.qkv_bias,
                    attn_mode=self.attn_mode,
                    reverse=self.reverse,
                    interleaved_audio_visual_rope=self.interleaved_audio_visual_rope,
                    **factory_kwargs,
                )
                for _ in range(self.depth_triple_blocks)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    qk_norm_type=self.qk_norm_type,
                    **factory_kwargs,
                )
                for _ in range(self.depth_single_blocks)
            ]
        )

        # Projects hidden states back to audio-VAE latent channels.
        self.final_layer = FinalLayer1D(
            self.hidden_size, self.patch_size, self.out_channels, get_activation_layer("silu"), **factory_kwargs
        )
        self.unpatchify_channels = self.out_channels

        # Learnable "empty" features substituted when visual conditions are
        # dropped (classifier-free guidance); initialized to zero.
        self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels), requires_grad=True)
        self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim), requires_grad=True)
        nn.init.constant_(self.empty_clip_feat, 0)
        nn.init.constant_(self.empty_sync_feat, 0)
|
| 512 |
+
|
| 513 |
+
    def get_empty_string_sequence(self, bs=None) -> torch.Tensor:
        """Return the learned "empty text" feature, optionally batch-expanded.

        NOTE(review): ``self.empty_string_feat`` is never created in the
        visible ``__init__`` (only ``empty_clip_feat`` and ``empty_sync_feat``
        are) — calling this will raise ``AttributeError`` unless the attribute
        is assigned somewhere outside this file. Confirm before relying on it.

        Args:
            bs: optional batch size; when given the feature is expanded along a
                new leading batch dimension.
        """
        if bs is None:
            return self.empty_string_feat
        else:
            return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)
|
| 518 |
+
|
| 519 |
+
def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
|
| 520 |
+
len = len if len is not None else self.clip_len
|
| 521 |
+
if bs is None:
|
| 522 |
+
return self.empty_clip_feat.expand(len, -1) # 15s
|
| 523 |
+
else:
|
| 524 |
+
return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1) # 15s
|
| 525 |
+
|
| 526 |
+
def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
|
| 527 |
+
len = len if len is not None else self.sync_len
|
| 528 |
+
if bs is None:
|
| 529 |
+
return self.empty_sync_feat.expand(len, -1)
|
| 530 |
+
else:
|
| 531 |
+
return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1)
|
| 532 |
+
|
| 533 |
+
def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
|
| 534 |
+
assert self.patch_size == 1
|
| 535 |
+
# ======================================== Build RoPE for audio tokens ======================================
|
| 536 |
+
target_ndim = 1 # n-d RoPE
|
| 537 |
+
rope_sizes = [audio_emb_len]
|
| 538 |
+
head_dim = self.hidden_size // self.num_heads
|
| 539 |
+
rope_dim_list = self.rope_dim_list
|
| 540 |
+
if rope_dim_list is None:
|
| 541 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
| 542 |
+
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
| 543 |
+
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
| 544 |
+
rope_dim_list=rope_dim_list,
|
| 545 |
+
start=rope_sizes,
|
| 546 |
+
theta=10000,
|
| 547 |
+
use_real=True,
|
| 548 |
+
theta_rescale_factor=1.0,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
# ========================== Build RoPE for clip tokens =========================
|
| 552 |
+
target_ndim = 1 # n-d RoPE
|
| 553 |
+
rope_sizes = [visual_cond_len]
|
| 554 |
+
head_dim = self.hidden_size // self.num_heads
|
| 555 |
+
rope_dim_list = self.rope_dim_list
|
| 556 |
+
if rope_dim_list is None:
|
| 557 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
| 558 |
+
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
| 559 |
+
v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed(
|
| 560 |
+
rope_dim_list=rope_dim_list,
|
| 561 |
+
start=rope_sizes,
|
| 562 |
+
theta=10000,
|
| 563 |
+
use_real=True,
|
| 564 |
+
theta_rescale_factor=1.0,
|
| 565 |
+
freq_scaling=1.0 * audio_emb_len / visual_cond_len,
|
| 566 |
+
)
|
| 567 |
+
return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin
|
| 568 |
+
|
| 569 |
+
def build_rope_for_interleaved_audio_visual(self, total_len):
|
| 570 |
+
assert self.patch_size == 1
|
| 571 |
+
# ========================== Build RoPE for audio tokens ========================
|
| 572 |
+
target_ndim = 1 # n-d RoPE
|
| 573 |
+
rope_sizes = [total_len]
|
| 574 |
+
head_dim = self.hidden_size // self.num_heads
|
| 575 |
+
rope_dim_list = self.rope_dim_list
|
| 576 |
+
if rope_dim_list is None:
|
| 577 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
| 578 |
+
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
| 579 |
+
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
| 580 |
+
rope_dim_list=rope_dim_list,
|
| 581 |
+
start=rope_sizes,
|
| 582 |
+
theta=10000,
|
| 583 |
+
use_real=True,
|
| 584 |
+
theta_rescale_factor=1.0,
|
| 585 |
+
)
|
| 586 |
+
return freqs_cos, freqs_sin
|
| 587 |
+
|
| 588 |
+
def set_attn_mode(self, new_mode):
|
| 589 |
+
for block in self.triple_blocks:
|
| 590 |
+
block.set_attn_mode(new_mode)
|
| 591 |
+
for block in self.single_blocks:
|
| 592 |
+
block.set_attn_mode(new_mode)
|
| 593 |
+
|
| 594 |
+
def enable_deterministic(self):
|
| 595 |
+
for block in self.triple_blocks:
|
| 596 |
+
block.enable_deterministic()
|
| 597 |
+
for block in self.single_blocks:
|
| 598 |
+
block.enable_deterministic()
|
| 599 |
+
|
| 600 |
+
def disable_deterministic(self):
|
| 601 |
+
for block in self.triple_blocks:
|
| 602 |
+
block.disable_deterministic()
|
| 603 |
+
for block in self.single_blocks:
|
| 604 |
+
block.disable_deterministic()
|
| 605 |
+
|
| 606 |
+
def forward(
    self,
    x: torch.Tensor,
    t: torch.Tensor,  # Should be in range(0, 1000).
    clip_feat: Optional[torch.Tensor] = None,
    cond: torch.Tensor = None,
    audio_mask: Optional[torch.Tensor] = None,
    cond_mask: torch.Tensor = None,
    sync_feat: Optional[torch.Tensor] = None,
    drop_visual: Optional[List[bool]] = None,
    return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Denoising forward pass conditioned on text, CLIP visual and sync features.

    Args:
        x: noisy audio latent, shape (batch, channels, length) — (bs, _, ol) below.
        t: diffusion timestep per sample (expected in range(0, 1000)).
        clip_feat: visual (CLIP) condition tokens; mutated in place for dropped rows.
        cond: text condition tokens, projected by ``self.cond_in``.
        audio_mask: unused here — it is overwritten when ``self.use_attention_mask``.
        cond_mask: boolean padding mask for ``cond``; required when
            ``self.use_attention_mask`` is set.
        sync_feat: synchronization features; assumed to have a frame axis divisible
            by 8 (segments of 8 frames get ``self.sync_pos_emb`` added).
        drop_visual: per-sample booleans; True rows have clip/sync features replaced
            by the learnable "empty" sequences (classifier-free guidance style).
        return_dict: if True return ``{"x": audio}``, else the tensor itself.

    Returns:
        Denoised audio latent of shape (N, C, T) (see ``unpatchify1d``), either
        bare or wrapped in a dict.
    """
    out = {}
    audio = x
    bs, _, ol = x.shape
    # Number of audio tokens after patching.
    tl = ol // self.patch_size

    # Prepare learnable empty conditions for visual condition
    # NOTE(review): this mutates the caller's clip_feat/sync_feat tensors in place.
    if drop_visual is not None:
        clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
        sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype)

    # ========================= Prepare time & visual modulation =========================
    vec = self.time_in(t)
    sync_vec = None
    if self.sync_modulation:
        # Per-token modulation: sync features are positionally embedded in groups
        # of 8 frames, projected, then resampled to the audio token length tl.
        assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
        sync_feat = sync_feat.view(bs, int(sync_feat.shape[1] / 8), 8, self.sync_feat_dim) + self.sync_pos_emb
        sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)  # bs, num_segments * 8, channels
        sync_vec = self.sync_in(sync_feat)  # bs, num_segments * 8, c
        sync_vec = (
            F.interpolate(sync_vec.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
        )  # bs, tl, c
        # Broadcast the timestep embedding onto every token position.
        sync_vec = sync_vec + vec.unsqueeze(1)
    elif self.add_sync_feat_to_audio:
        # Alternative path: same resampling, but the result is ADDED to the audio
        # stream at a chosen triple block instead of used as modulation.
        assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
        sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
        sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)  # bs, num_segments * 8, channels
        sync_feat = self.sync_in(sync_feat)  # bs, num_segments * 8, c
        add_sync_feat_to_audio = (
            F.interpolate(sync_feat.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
        )  # bs, tl, c

    # ========================= Get text, audio and video clip embedding =========================
    cond = self.cond_in(cond)
    cond_seq_len = cond.shape[1]

    audio = self.audio_embedder(x)
    audio_seq_len = audio.shape[1]
    v_cond = self.visual_proj(clip_feat)
    v_cond_seq_len = v_cond.shape[1]

    # ========================= Compute attention mask =========================
    attn_mask = None
    if self.use_attention_mask:
        assert cond_mask is not None
        batch_size = audio.shape[0]
        # Full sequence = [text cond | visual cond | audio] concatenated.
        seq_len = cond_seq_len + v_cond_seq_len + audio_seq_len

        # get default audio_mask and v_cond_mask (audio/visual tokens are never padded)
        audio_mask = torch.ones((batch_size, audio_seq_len), dtype=torch.bool, device=audio.device)
        v_cond_mask = torch.ones((batch_size, v_cond_seq_len), dtype=torch.bool, device=audio.device)

        # batch_size x seq_len
        concat_mask = torch.cat([cond_mask, v_cond_mask, audio_mask], dim=1)
        # batch_size x 1 x seq_len x seq_len
        attn_mask_1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
        # batch_size x 1 x seq_len x seq_len
        attn_mask_2 = attn_mask_1.transpose(2, 3)
        # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
        attn_mask = (attn_mask_1 & attn_mask_2).bool()
        # avoids self-attention weight being NaN for text padding tokens
        attn_mask[:, :, :, 0] = True

    # ========================= Build rope for audio and clip tokens =========================
    if self.interleaved_audio_visual_rope:
        # One shared table over audio+visual interleaved positions (2x audio length).
        freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
        v_freqs_cos = v_freqs_sin = None
    else:
        freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin = self.build_rope_for_audio_visual(
            audio_seq_len, v_cond_seq_len
        )

    # ========================= Pass through DiT blocks =========================
    freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
    v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None

    if self.add_sync_feat_to_audio:
        # Sync features are injected before the first triple block.
        add_sync_layer = 0
        assert (
            add_sync_layer < self.depth_triple_blocks
        ), f"The layer to add mel_spectrogram feature and sync feature should in the triple_stream_blocks (n: {self.depth_triple_blocks})."
    # Triple-stream blocks
    for layer_num, block in enumerate(self.triple_blocks):
        if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
            audio = audio + add_sync_feat_to_audio
        triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
        if (
            self.training
            and self.gradient_checkpoint
            and (self.gradient_checkpoint_layers == -1 or layer_num < self.gradient_checkpoint_layers)
        ):
            # use_reentrant=False is required for keyword args / non-tensor inputs.
            audio, cond, v_cond = torch.utils.checkpoint.checkpoint(
                ckpt_wrapper(block), *triple_block_args, use_reentrant=False
            )
        else:
            audio, cond, v_cond = block(*triple_block_args)

    x = audio
    if sync_vec is not None:
        # Extend the per-token modulation so it also covers the cond/v_cond span.
        vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
        vec = torch.cat((vec, sync_vec), dim=1)

    # Single-stream blocks use the audio-only RoPE table.
    freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
    if self.add_sync_feat_to_audio:
        # NOTE(review): assumes vec is still 2-D here (sync_modulation and
        # add_sync_feat_to_audio look mutually exclusive) — verify.
        vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1)
    if len(self.single_blocks) > 0:
        for layer_num, block in enumerate(self.single_blocks):
            single_block_args = [
                x,
                vec,
                (freqs_cos, freqs_sin),
            ]
            if (
                self.training
                and self.gradient_checkpoint
                and (
                    self.gradient_checkpoint_layers == -1
                    or layer_num + len(self.triple_blocks) < self.gradient_checkpoint_layers
                )
            ):
                x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False)
            else:
                x = block(*single_block_args)

    audio = x

    # ========================= Final layer =========================
    if sync_vec is not None:
        # The final layer modulates with the audio-span sync vector only.
        vec = sync_vec
    audio = self.final_layer(audio, vec)  # (N, T, patch_size * out_channels)
    audio = self.unpatchify1d(audio, tl)

    if return_dict:
        out["x"] = audio
        return out
    return audio
+
|
| 755 |
+
def unpatchify1d(self, x, l):
    """Reassemble patch tokens into a channel-first 1-D signal.

    Args:
        x: patch tokens of shape (N, L, patch_size * C).
        l: expected number of patches; must equal ``x.shape[1]``.

    Returns:
        Tensor of shape (N, C, T) with T == L * patch_size.
    """
    channels = self.unpatchify_channels
    patch = self.patch_size
    assert l == x.shape[1]

    batch = x.shape[0]
    # (N, L, P*C) -> (N, L, P, C) -> (N, C, L, P) -> (N, C, L*P)
    tokens = x.reshape(batch, l, patch, channels)
    tokens = tokens.permute(0, 3, 1, 2)
    return tokens.reshape(batch, channels, l * patch)
+
|
| 767 |
+
def params_count(self):
    """Tally parameter counts.

    Returns a dict with:
        "triple": attention + MLP parameters across all triple-stream blocks,
        "single": linear1/linear2 parameters across all single-stream blocks,
        "total": all parameters of the model,
        "attn+mlp": "triple" + "single".
    """
    def _numel(module):
        # Total number of scalar parameters in one submodule.
        return sum(p.numel() for p in module.parameters())

    triple_total = 0
    for block in self.triple_blocks:
        triple_total += (
            _numel(block.audio_cross_q)
            + _numel(block.v_cond_cross_q)
            + _numel(block.text_cross_kv)
            + _numel(block.audio_self_attn_qkv)
            + _numel(block.v_cond_attn_qkv)
            + _numel(block.audio_mlp)
            + _numel(block.audio_self_proj)
            + _numel(block.v_cond_self_proj)
            + _numel(block.v_cond_mlp)
        )

    single_total = sum(
        _numel(block.linear1) + _numel(block.linear2) for block in self.single_blocks
    )

    counts = {
        "triple": triple_total,
        "single": single_total,
        "total": sum(p.numel() for p in self.parameters()),
    }
    counts["attn+mlp"] = counts["triple"] + counts["single"]
    return counts
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/__init__.py
ADDED
|
File without changes
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/activation_layers.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
def get_activation_layer(act_type):
    """Map an activation name to a zero-argument factory producing the module.

    Supported names: "gelu", "gelu_tanh" (tanh approximation, torch >= 1.13),
    "relu", "silu".

    Raises:
        ValueError: for any other name.
    """
    factories = {
        "gelu": lambda: nn.GELU(),
        # Approximate `tanh` requires torch >= 1.13
        "gelu_tanh": lambda: nn.GELU(approximate="tanh"),
        "relu": nn.ReLU,
        "silu": nn.SiLU,
    }
    if act_type not in factories:
        raise ValueError(f"Unknown activation type: {act_type}")
    return factories[act_type]
+
|
| 17 |
+
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: input feature dimension.
        hidden_dim: width of the gated hidden layer.
        out_dim: output feature dimension.

    Attributes:
        w1: gate-branch projection (dim -> hidden_dim).
        w2: output projection (hidden_dim -> out_dim).
        w3: value-branch projection (dim -> hidden_dim).
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        super().__init__()
        # All three projections are bias-free, matching the original SwiGLU design.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # SiLU-gated product of the two input branches, then project out.
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/attn_layers.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib.metadata
|
| 2 |
+
import math
|
| 3 |
+
from typing import Tuple, Union
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from flash_attn import (
|
| 11 |
+
flash_attn_qkvpacked_func,
|
| 12 |
+
flash_attn_kvpacked_func,
|
| 13 |
+
flash_attn_varlen_kvpacked_func,
|
| 14 |
+
flash_attn_varlen_qkvpacked_func,
|
| 15 |
+
)
|
| 16 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
|
| 17 |
+
except ImportError:
|
| 18 |
+
flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func = None, None, None
|
| 19 |
+
index_first_axis = None
|
| 20 |
+
from packaging import version
|
| 21 |
+
from transformers.utils.import_utils import _is_package_available
|
| 22 |
+
|
| 23 |
+
from .norm_layers import get_norm_layer
|
| 24 |
+
|
| 25 |
+
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    # Effectively asserts ndim >= 2 (the 0 <= 1 part is always true).
    assert 0 <= 1 < ndim

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            # head_first layout: x is (..., seq, head_dim); table must match those two axes.
            assert freqs_cis[0].shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            # Keep the last two axes, set every other axis to 1 for broadcasting.
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            # seq-first layout: x is (batch, seq, ..., head_dim).
            assert freqs_cis[0].shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            # Keep axis 1 (seq) and the last axis (head_dim), set the rest to 1.
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis.shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)
+
|
| 82 |
+
|
| 83 |
+
def rotate_half(x):
    """Rotate interleaved (real, imag) pairs by 90 degrees: (r, i) -> (-i, r).

    Expects a 4-D tensor [B, S, H, D] with D even; returns float32 of the same shape.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)  # [B, S, H, D//2, 2]
    real, imag = pairs.unbind(-1)
    return torch.stack((-imag, real), dim=-1).flatten(3)
+
|
| 87 |
+
|
| 88 |
+
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
        freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
            A (cos, sin) tuple selects the real-valued path; a single tensor the complex path.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.

    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        # Real-valued path: rotation expressed as x*cos + rotate_half(x)*sin.
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        # real * cos - imag * sin
        # imag * cos + real * sin
        # Computation is done in float32, then cast back to the input dtype.
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        # Complex path: multiply by unit complex exponentials.
        # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
        # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
        # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out
+
|
| 132 |
+
|
| 133 |
+
class BasicAttentionLayer(nn.Module):
    """Base class for attention modules.

    Holds two pieces of runtime-switchable state: the attention backend name
    (``attn_mode``) and whether deterministic kernels are requested
    (``deterministic``).
    """

    def __init__(self, attn_mode="flash", deterministic=False):
        super().__init__()
        # Which attention implementation to dispatch to (e.g. "flash", "torch").
        self.attn_mode = attn_mode
        # Whether to request deterministic attention kernels.
        self.deterministic = deterministic

    def set_attn_mode(self, new_mode):
        """Switch the attention backend at runtime."""
        self.attn_mode = new_mode

    def enable_deterministic(self):
        """Request deterministic attention kernels."""
        self.deterministic = True

    def disable_deterministic(self):
        """Allow non-deterministic (faster) attention kernels."""
        self.deterministic = False
+
|
| 148 |
+
|
| 149 |
+
# Maps attention mode -> (pre_attn_layout, post_attn_layout) transforms applied
# to q/k/v before the kernel and to the output after it.
# The flash modes are identity (flash-attn consumes [B, S, H, D] directly),
# while "torch" (SDPA) and "vanilla" transpose to [B, H, S, D] and back.
MEMORY_LAYOUT = {
    "self_flash": (
        lambda x: x,
        lambda x: x,
    ),
    "cross_flash": (
        lambda x: x,
        lambda x: x,
    ),
    # NOTE(review): "flash_torch_sp" has no branch in attention() below — presumably
    # handled elsewhere; confirm before relying on it.
    "flash_torch_sp": (
        lambda x: x,
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}
+
|
| 172 |
+
|
| 173 |
+
# Copyed from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/modeling_flash_attention_utils.py#L33C1-L57C6
|
| 174 |
+
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
| 175 |
+
"""
|
| 176 |
+
Retrieves indexing data required to repad unpadded (ragged) tensors.
|
| 177 |
+
|
| 178 |
+
Arguments:
|
| 179 |
+
attention_mask (`torch.Tensor`):
|
| 180 |
+
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
| 181 |
+
|
| 182 |
+
Return:
|
| 183 |
+
indices (`torch.Tensor):
|
| 184 |
+
The indices of non-masked tokens from the flattened input sequence.
|
| 185 |
+
cu_seqlens (`torch.Tensor`):
|
| 186 |
+
The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
| 187 |
+
max_seqlen_in_batch (`int`):
|
| 188 |
+
Maximum sequence length in batch.
|
| 189 |
+
"""
|
| 190 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 191 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 192 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 193 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 194 |
+
return (
|
| 195 |
+
indices,
|
| 196 |
+
cu_seqlens,
|
| 197 |
+
max_seqlen_in_batch,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# Copied from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/utils/import_utils.py#L822
def is_flash_attn_greater_or_equal(library_version: str):
    """Return True iff flash_attn is installed at a version >= ``library_version``."""
    if not _is_package_available("flash_attn"):
        return False

    installed = version.parse(importlib.metadata.version("flash_attn"))
    return installed >= version.parse(library_version)
+
|
| 208 |
+
|
| 209 |
+
def get_kv_seqlens_with_mask(attn_mask, k, v):
    """Unpad key/value tensors for varlen flash attention.

    Args:
        attn_mask: (batch, src_len) validity mask for the KV sequence.
        k, v: (batch, src_len, heads, dim) key/value tensors.

    Returns:
        cu_seqlens_k: cumulative KV lengths (batch + 1,), int32.
        max_seqlen_k: longest KV sequence in the batch.
        kv: packed (tokens, 2, heads, dim) tensor with padding rows removed.
    """
    indices_k, cu_seqlens_k, max_seqlen_k = _get_unpad_data(attn_mask)
    batch, src_len, heads, dim = k.shape
    # Keep only the valid token rows, in flattened (batch*src_len) order.
    flat_k = index_first_axis(k.reshape(batch * src_len, heads, dim), indices_k)
    flat_v = index_first_axis(v.reshape(batch * src_len, heads, dim), indices_k)
    kv = torch.stack([flat_k, flat_v], dim=1)
    return cu_seqlens_k, max_seqlen_k, kv
+
|
| 217 |
+
|
| 218 |
+
def get_q_seqlens(q):
    """Flatten a dense (B, S, H, D) query tensor for varlen attention.

    Since queries are unpadded, every row has the same length S.

    Returns:
        cu_seqlens_q: int32 offsets [0, S, 2S, ..., B*S].
        max_seqlen: S.
        q: reshaped to (B*S, H, D).
    """
    batch, seq, heads, dim = q.shape
    cu_seqlens_q = torch.arange(0, (batch + 1) * seq, step=seq, dtype=torch.int32, device=q.device)
    return cu_seqlens_q, seq, q.reshape(batch * seq, heads, dim)
+
|
| 224 |
+
def flash_attn_no_pad(
    qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None
):
    """Run flash attention on padded packed QKV by unpadding, attending, and repadding.

    Args:
        qkv: packed tensor of shape (batch, seq, 3, heads, dim).
        key_padding_mask: (batch, seq) validity mask; True/1 for real tokens.
        causal: apply a causal mask inside the kernel.
        dropout_p: attention dropout probability.
        softmax_scale: optional override of the softmax scaling factor.

    Returns:
        (batch, seq, heads, dim) attention output, re-padded to the input layout.
    """
    # adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27
    batch_size = qkv.shape[0]
    seqlen = qkv.shape[1]
    nheads = qkv.shape[-2]
    # Collapse (three, heads, dim) so unpad_input sees a 2-D-per-token layout.
    x = rearrange(qkv, "b s three h d -> b s (three h d)")
    # unpad_input's return arity differs across flash_attn versions:
    # x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch
    # x_unpad, indices, cu_seqlens, max_s
    unpad_results = unpad_input(
        x, key_padding_mask
    )

    if len(unpad_results) == 4:
        x_unpad, indices, cu_seqlens, max_s = unpad_results
    elif len(unpad_results) == 5:
        x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_results
    else:
        raise ValueError

    # Restore the packed (three, heads, dim) axes on the unpadded tokens.
    x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad,
        cu_seqlens,
        max_s,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    # Scatter the unpadded output back into the (batch, seq) padded layout.
    output = rearrange(
        pad_input(
            rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
        ),
        "b s (h d) -> b s h d",
        h=nheads,
    )
    return output
+
|
| 263 |
+
|
| 264 |
+
def attention(
    q,
    k,
    v,
    mode,
    drop_rate=0,
    attn_mask=None,
    cond_mask=None,
    causal=False,
    deterministic=False,
    cu_seqlens=None,
    max_seqlen=None,
    cu_seqlens_k=None,
    max_seqlen_k=None,
    img_seq_len=None,
):
    """
    Perform QKV self attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
            (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        deterministic (bool): Whether to use deterministic attention. (default: False)
        cu_seqlens (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into q.
        max_seqlen (int): The maximum sequence length in the batch of q.
        cu_seqlens_k (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into kv.
        max_seqlen_k (int): The maximum sequence length in the batch of k and v.

    Returns:
        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
    """
    # Normalize inputs and move to the memory layout each backend expects.
    if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
        # Tuple inputs are segment lists; concatenate them along the sequence axis.
        if isinstance(q, tuple):
            q = torch.cat(q, dim=1)
        if isinstance(k, tuple):
            k = torch.cat(k, dim=1)
        if isinstance(v, tuple):
            v = torch.cat(v, dim=1)
        pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
        q = pre_attn_layout(q)
        k = pre_attn_layout(k)
        v = pre_attn_layout(v)

    if "flash" in mode:
        assert (
            flash_attn_qkvpacked_func is not None
        ), "Flash attention is not available. Please install flash_attn first."
        flash_kwargs = dict(dropout_p=drop_rate, causal=causal)
        if deterministic:
            # The `deterministic` kwarg only exists in flash_attn >= 2.4.1.
            if not is_flash_attn_greater_or_equal("2.4.1"):
                raise ValueError(
                    "Flash attention deterministic mode requires flash_attn>=2.4.1. " "Please upgrade flash_attn"
                )
            flash_kwargs["deterministic"] = deterministic

        if mode == "self_flash":
            # Pack q/k/v into the (b, s, 3, a, d) layout the packed kernel expects.
            qkv = torch.stack([q, k, v], dim=2)
            if attn_mask is not None:
                raise ValueError("Self attention does not support attention mask")
            x = flash_attn_qkvpacked_func(qkv, **flash_kwargs)

        elif mode == "cross_flash":
            kv = torch.stack([k, v], dim=2)
            if attn_mask is None:
                x = flash_attn_kvpacked_func(q, kv, **flash_kwargs)
            else:
                # Masked cross attention: unpad KV per the mask and use the
                # variable-length kernel. Queries are dense (uniform length).
                b, s, a, h = q.shape
                cu_seqlens_q, max_seqlen_q, q = get_q_seqlens(q)
                cu_seqlens_k, max_seqlen_k, kv = get_kv_seqlens_with_mask(attn_mask, k, v)

                attn_output = flash_attn_varlen_kvpacked_func(
                    q,
                    kv,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_k=max_seqlen_k,
                    **flash_kwargs,
                )
                # Restore the dense (b, s, a, d) layout.
                x = attn_output.reshape(b, s, a, h)
    elif mode == 'torch':
        # SDPA accepts bool masks directly; float masks must match q's dtype.
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)

    elif mode == "vanilla":
        # Reference implementation: explicit softmax(QK^T / sqrt(d) + bias) V.
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            # NOTE(review): `.to()` is not in-place — this line has no effect
            # (attn_bias is already q.dtype from torch.zeros above).
            attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        # TODO(jarvizhang): Maybe force q and k to be float32 to avoid numerical overflow
        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        # NOTE(review): train=True applies dropout even at eval time — confirm intended.
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    # Undo the backend-specific layout and merge heads into the channel axis.
    if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
        x = post_attn_layout(x).contiguous()
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
+
|
| 391 |
+
|
| 392 |
+
class SelfAttentionLayer(BasicAttentionLayer):
    """Multi-head self-attention with optional QK-normalization and RoPE.

    Projects the input to Q/K/V with one fused linear layer, optionally
    normalizes Q and K per head, applies rotary position embeddings when
    provided, and runs the configured attention backend before the output
    projection.
    """

    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=True,
        qk_norm=True,
        attn_drop=0,
        proj_drop=0,
        dtype=None,
        device=None,
        norm_type="layer",
        attn_mode="self_flash",
        deterministic=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(attn_mode, deterministic)
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = self.dim // num_heads
        self.attn_drop = attn_drop

        # This assertion is aligned with flash attention
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"

        # Fused projection producing Q, K and V in a single matmul.
        self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)

        norm_layer = get_norm_layer(norm_type)
        if qk_norm:
            self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, freqs_cis=None, attn_mask=None):
        """
        Args:
            x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
            freqs_cis (torch.Tensor, optional): (batch, hidden_dim // 2), RoPE for image
            attn_mask (torch.Tensor, optional): (batch, seq_len, seq_len), mask for attention
        """
        batch, seq_len, _ = x.shape

        # Fused QKV projection, then split into per-head tensors.
        qkv = self.Wqkv(x).view(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each: [b, s, heads, head_dim]

        # Optional per-head QK normalization.
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Rotary embedding keeps shapes unchanged; the assert guards that.
        if freqs_cis is not None:
            qq, kk = apply_rotary_emb(q, k, freqs_cis)
            assert (
                qq.shape == q.shape and kk.shape == k.shape
            ), f"qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}"
            q, k = qq, kk

        # Dropout is only active during training.
        context = attention(
            q,
            k,
            v,
            drop_rate=self.attn_drop if self.training else 0,
            attn_mask=attn_mask,
            mode=self.attn_mode,
            deterministic=self.deterministic,
        )
        projected = self.out_proj(context)
        return self.proj_drop(projected)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
class CrossAttentionLayer(BasicAttentionLayer):
    """Multi-head cross-attention: queries from `x`, keys/values from `y`.

    Q is projected from the query stream and K/V jointly from the context
    stream; optional per-head QK-normalization is applied before running the
    configured attention backend and the output projection.
    """

    def __init__(
        self,
        qdim,
        kdim,
        num_heads,
        qkv_bias=True,
        qk_norm=True,
        attn_drop=0,
        proj_drop=0,
        dtype=None,
        device=None,
        norm_type="layer",
        attn_mode="cross_flash",
        deterministic=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(attn_mode, deterministic)
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        self.attn_drop = attn_drop

        # This assertion is aligned with flash attention
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"

        # Separate projections: Q from the query stream, fused K/V from context.
        self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)

        norm_layer = get_norm_layer(norm_type)
        if qk_norm:
            self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y, attn_mask=None):
        """
        Args:
            x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
            y (torch.Tensor): (batch, seq_len1, hidden_dim1)
            attn_mask (torch.Tensor): (batch, seq_len1), mask for attention
        """
        batch, q_len, _ = x.shape
        ctx_len = y.shape[1]

        q = self.q_proj(x).view(batch, q_len, self.num_heads, self.head_dim)
        kv = self.kv_proj(y).view(batch, ctx_len, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(dim=2)

        # Optional per-head normalization of queries and keys.
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Dropout is only active during training.
        context = attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            drop_rate=self.attn_drop if self.training else 0,
            mode=self.attn_mode,
            deterministic=self.deterministic,
        )
        projected = self.out_proj(context)
        return self.proj_drop(projected)
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/embed_layers.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from ...utils.helper import to_2tuple, to_1tuple
|
| 6 |
+
|
| 7 |
+
class PatchEmbed1D(nn.Module):
    """1D Audio to Patch Embedding.

    A convolution based approach to patchifying a 1D audio w/ embedding projection.

    Based on the impl in https://github.com/google-research/vision_transformer

    Hacked together by / Copyright 2020 Ross Wightman
    """

    def __init__(
        self,
        patch_size=1,
        in_chans=768,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        patch_size = to_1tuple(patch_size)
        self.patch_size = patch_size
        self.flatten = flatten

        # Non-overlapping conv (stride == kernel) performs patchify + projection.
        self.proj = nn.Conv1d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
        )
        nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
        if bias:
            nn.init.zeros_(self.proj.bias)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Patchify a (B, C, N) sequence into (B, N/patch, embed_dim) tokens.

        Args:
            x (torch.Tensor): (batch, in_chans, num_tokens) input.

        Returns:
            torch.Tensor: (batch, num_tokens // patch_size, embed_dim) if
            `flatten` is True, else (batch, embed_dim, num_tokens // patch_size).
        """
        # BUGFIX: the original assertion message stated the divisibility
        # backwards ("patch_size must be divisible by the token number");
        # the token count must be divisible by the patch size.
        assert (
            x.shape[2] % self.patch_size[0] == 0
        ), f"The token number ({x.shape[2]}) of x must be divisible by the patch size ({self.patch_size[0]})."

        x = self.proj(x)
        if self.flatten:
            x = x.transpose(1, 2)  # BCN -> BNC
        x = self.norm(x)
        return x
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class ConditionProjection(nn.Module):
    """
    Projects condition embeddings through a two-layer MLP. Also handles
    dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        # Two linear layers with an activation in between.
        self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)

    def forward(self, caption):
        # linear -> activation -> linear, expressed as a nested pipeline.
        projected = self.linear_1(caption)
        return self.linear_2(self.act_1(projected))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    Args:
        t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
        dim (int): the dimension of the output.
        max_period (int): controls the minimum frequency of the embeddings.

    Returns:
        embedding (torch.Tensor): An (N, D) Tensor of positional embeddings
        (cos half first, then sin half; zero-padded when `dim` is odd).

    .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    """
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponent).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad one zero column so the output width matches an odd `dim`.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal embedding of the timestep is passed through a small MLP
    (linear -> activation -> linear) to produce the final vector.
    """

    def __init__(self,
                 hidden_size,
                 act_layer,
                 frequency_embedding_size=256,
                 max_period=10000,
                 out_size=None,
                 dtype=None,
                 device=None
                 ):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        # Default the output width to the hidden width.
        out_size = hidden_size if out_size is None else out_size

        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )
        for layer in (self.mlp[0], self.mlp[2]):
            nn.init.normal_(layer.weight, std=0.02)

    def forward(self, t):
        # Match the MLP's parameter dtype before projecting.
        target_dtype = self.mlp[0].weight.dtype
        freq_emb = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(target_dtype)
        return self.mlp(freq_emb)
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/mlp_layers.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from timm library:
|
| 2 |
+
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
|
| 3 |
+
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from .modulate_layers import modulate
|
| 11 |
+
from ...utils.helper import to_2tuple
|
| 12 |
+
|
| 13 |
+
class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        # Per-layer (bias, dropout) settings; scalars broadcast to both layers.
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)
        # A 1x1 conv acts as a per-position linear layer on feature maps.
        make_linear = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = make_linear(in_channels, hidden_channels, bias=bias_pair[0], **factory_kwargs)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
        self.fc2 = make_linear(hidden_channels, out_features, bias=bias_pair[1], **factory_kwargs)
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        # Pipeline: fc1 -> act -> drop -> (norm) -> fc2 -> drop.
        for stage in (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2):
            x = stage(x)
        return x
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
| 55 |
+
# only used when use_vanilla is True
|
| 56 |
+
# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
# only used when use_vanilla is True
class MLPEmbedder(nn.Module):
    """Two-layer SiLU MLP: in_dim -> hidden_dim -> hidden_dim."""

    def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.silu(self.in_layer(x))
        return self.out_layer(hidden)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class LinearWarpforSingle(nn.Module):
    """Concatenates two streams along the feature axis and projects them."""

    def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs)

    def forward(self, x, y):
        # Feature-wise concat (dim=2 assumes (batch, seq, feat) inputs).
        combined = torch.cat([x, y], dim=2)
        return self.fc(combined)
|
| 77 |
+
|
| 78 |
+
class FinalLayer1D(nn.Module):
    """Final adaLN-modulated projection from hidden states to patch outputs."""

    def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # Plain non-affine LayerNorm — the modulation supplies shift/scale.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        # Simple shift/scale modulation conditioned on `c`.
        self.adaLN_modulation = nn.Sequential(
            act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
        )
        # Zero init keeps the modulation a no-op at the start of training.
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift=shift, scale=scale)
        return self.linear(modulated)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class ChannelLastConv1d(nn.Conv1d):
    """Conv1d that accepts channel-last input (B, N, C) instead of (B, C, N)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Swap to channel-first for the underlying Conv1d, then swap back.
        channel_first = x.permute(0, 2, 1)
        convolved = super().forward(channel_first)
        return convolved.permute(0, 2, 1)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class ConvMLP(nn.Module):
    """SwiGLU-style convolutional MLP over channel-last 1D sequences."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
        device=None,
        dtype=None,
    ):
        """
        Convolutional MLP module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.

        Attributes:
            w1: Linear transformation for the first layer.
            w2: Linear transformation for the second layer.
            w3: Linear transformation for the third layer.

        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Shrink to 2/3 (gated-MLP convention), then round up to `multiple_of`.
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        conv_args = dict(bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
        self.w1 = ChannelLastConv1d(dim, hidden_dim, **conv_args)
        self.w2 = ChannelLastConv1d(hidden_dim, dim, **conv_args)
        self.w3 = ChannelLastConv1d(dim, hidden_dim, **conv_args)

    def forward(self, x):
        # silu(w1(x)) gates w3(x); w2 projects back to the input width.
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/modulate_layers.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
class ModulateDiT(nn.Module):
    """Produces AdaLN modulation parameters from a conditioning vector.

    Applies an activation followed by a zero-initialized linear layer that
    expands the hidden size by `factor` (e.g. shift/scale/gate chunks).
    """

    def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        # Zero init makes the modulation an identity at the start of training.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = self.act(x)
        return self.linear(activated)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def modulate(x, shift=None, scale=None):
    """Apply AdaLN-style affine modulation: ``x * (1 + scale) + shift``.

    Args:
        x (torch.Tensor): input, typically (batch, seq_len, dim).
        shift (torch.Tensor, optional): additive term; (batch, dim) is
            broadcast over the sequence dimension when `x` is 3-D.
        scale (torch.Tensor, optional): multiplicative term (applied as
            ``1 + scale``); same broadcasting rule as `shift`.

    Returns:
        torch.Tensor: modulated tensor, same shape as `x`.
    """
    if x.ndim == 3:
        # Broadcast per-sample (2-D) modulation over the sequence dimension.
        # BUGFIX: the previous version replaced an already 3-D shift/scale
        # with None, silently discarding a provided modulation; keep them.
        if shift is not None and shift.ndim == 2:
            shift = shift.unsqueeze(1)
        if scale is not None and scale.ndim == 2:
            scale = scale.unsqueeze(1)
    if shift is None and scale is None:
        return x
    if shift is None:
        return x * (1 + scale)
    if scale is None:
        return x + shift
    return x * (1 + scale) + shift
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def apply_gate(x, gate=None, tanh=False):
    """Scale `x` by `gate` (optionally tanh-squashed); no-op when gate is None."""
    if gate is None:
        return x
    # Broadcast a per-sample gate over the sequence dimension.
    if gate.ndim == 2 and x.ndim == 3:
        gate = gate.unsqueeze(1)
    return x * (gate.tanh() if tanh else gate)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def ckpt_wrapper(module):
    """Wrap `module` in a positional-args-only callable (checkpoint-friendly)."""

    def ckpt_forward(*inputs):
        return module(*inputs)

    return ckpt_forward
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/norm_layers.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering).

    Normalizes the last dimension by its RMS value; when
    `elementwise_affine` is True a learnable per-channel scale is applied.
    """

    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6,
                 device=None, dtype=None):
        """
        Args:
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool): Whether to learn a per-channel scale.
            eps (float, optional): Small value added to the denominator for
                numerical stability. Default is 1e-6.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            # Learnable per-channel gain; attribute is absent when affine is off.
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """Divide by the RMS of the last dimension (with eps for stability)."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize in float32 for stability, cast back, then scale if affine."""
        normalized = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            normalized = normalized * self.weight
        return normalized
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_norm_layer(norm_layer):
    """
    Get the normalization layer class for a given name.

    Args:
        norm_layer (str): The type of normalization layer ("layer" or "rms").

    Returns:
        norm_layer (nn.Module): The normalization layer class.

    Raises:
        NotImplementedError: For any unrecognized name.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    if norm_layer == "rms":
        return RMSNorm
    raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/posemb_layers.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Union, Tuple
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def _to_tuple(x, dim=2):
|
| 6 |
+
if isinstance(x, int):
|
| 7 |
+
return (x,) * dim
|
| 8 |
+
elif len(x) == dim:
|
| 9 |
+
return x
|
| 10 |
+
else:
|
| 11 |
+
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_meshgrid_nd(start, *args, dim=2):
    """
    Get n-D meshgrid with start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
            step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
            should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
            n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (torch.Tensor): [dim, ...]
    """
    if len(args) == 0:
        # Only a grid size was given: span [0, size) on each axis.
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # (start, stop) with an implicit step of 1.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [b - a for a, b in zip(start, stop)]
    elif len(args) == 2:
        # (start, stop, num) fully specified.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = _to_tuple(args[1], dim=dim)
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch equivalent of np.linspace(a, b, n, endpoint=False) per axis.
    axes = []
    for a, b, n in zip(start, stop, num):
        axes.append(torch.linspace(a, b, n + 1, dtype=torch.float32)[:n])
    mesh = torch.meshgrid(*axes, indexing="ij")
    return torch.stack(mesh, dim=0)  # [dim, *axis_sizes]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
#################################################################################
|
| 60 |
+
# Rotary Positional Embedding Functions #
|
| 61 |
+
#################################################################################
|
| 62 |
+
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_nd_rotary_pos_embed(
    rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0
):
    """
    This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
            sum(rope_dim_list) should equal to head_dim of attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
            args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
        freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2], or a (cos, sin) pair when `use_real`.
    """
    ndim = len(rope_dim_list)
    grid = get_meshgrid_nd(start, *args, dim=ndim)  # [n, ...]

    # Encode each grid axis with its own slice of the head dimension.
    per_axis = [
        get_1d_rotary_pos_embed(
            axis_dim,
            axis_grid.reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor,
            freq_scaling=freq_scaling,
        )
        for axis_dim, axis_grid in zip(rope_dim_list, grid)
    ]

    if use_real:
        # Concatenate per-axis (cos, sin) pairs along the feature dimension.
        cos = torch.cat([pair[0] for pair in per_axis], dim=1)  # (WHD, D/2)
        sin = torch.cat([pair[1] for pair in per_axis], dim=1)  # (WHD, D/2)
        return cos, sin
    return torch.cat(per_axis, dim=1)  # (WHD, D/2)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    freq_scaling: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the frequency tensor for complex exponential (cis) with given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return real part and imaginary part separately.
            Otherwise, return complex numbers.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
        freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0.

    Returns:
        freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
        freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # NTK-aware rescaling (proposed by reddit user bloc97): stretch theta so
    # longer sequences can be handled without fine-tuning.
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 1))

    half = dim // 2
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[:half].float() / dim))  # [D/2]
    inv_freq *= freq_scaling
    angles = torch.outer(pos, inv_freq)  # [S, D/2]

    if use_real:
        # Duplicate each frequency so cos/sin align with interleaved pairs.
        freqs_cos = angles.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = angles.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    return torch.polar(torch.ones_like(angles), angles)  # complex64, [S, D/2]
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .synchformer import Synchformer
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/ast_model.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
| 5 |
+
|
| 6 |
+
from .modeling_ast import ASTForAudioClassification, ASTConfig
|
| 7 |
+
from .motionformer import AveragePooling, BaseEncoderLayer, TemporalTransformerEncoderLayer
|
| 8 |
+
from .utils import check_if_file_exists_else_download
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AST(torch.nn.Module):
    """Audio Spectrogram Transformer (AST) wrapper used as an audio feature extractor.

    Wraps a HuggingFace `ASTForAudioClassification` model. Depending on the
    constructor flags it either keeps the classification head (`extract_features=False`)
    or replaces it with factorized frequency/time aggregation modules plus an
    optional global (cross-segment) aggregation module.
    """

    def __init__(
        self,
        extract_features: bool = False,
        ckpt_path: str = None,
        feat_type: str = None,
        max_spec_t: int = None,
        factorize_freq_time: bool = None,
        agg_freq_module: str = None,
        agg_time_module: str = None,
        add_global_repr: bool = True,
        agg_segments_module: str = None,
        max_segments: int = None,
    ) -> None:
        """
        extract_features: if True, then the model will return the features instead of head's output
        ckpt_path: is not a path to a ckpt file, but a name of a model from the HuggingFace model hub.
        feat_type: if extract_features is True, this parameter specifies the type of features to return
        max_spec_t: if specified, then the model (pos emb) will be patched to support this length of spec
        factorize_freq_time: if True, then the model will use a factorized freq/time aggregation
        agg_freq_module: if specified, then the model will use this module for freq aggregation
        agg_time_module: if specified, then the model will use this module for time aggregation
        add_global_repr: if True, adds a global representation to the features (aggregation on segments)
        agg_segments_module: if specified, then the model will use this module for segments aggregation
        max_segments: if specified, the initialization of PE in the global agg module will use this value.
            This should correspond to the max number of segments per video (if None, 16 is used)
        """
        super().__init__()
        self.extract_features = extract_features
        self.ckpt_path = ckpt_path
        self.max_spec_t = max_spec_t
        self.max_segments = max_segments

        # depending on whether the feat extractor was pre-trained contrastively or not, we need to
        # load the state dict differently.

        # if ckpt is specified, then load the model from the HuggingFace model hub, otherwise init a new model
        if ckpt_path == "MIT/ast-finetuned-audioset-10-10-0.4593":
            revision = "c1c0c66"  # fixing the revision for compatibility (V4.27.4)
            self.config = ASTConfig.from_pretrained(ckpt_path, revision=revision)
            full_model = ASTForAudioClassification.from_pretrained(ckpt_path, revision=revision)
            logging.info(f"Loaded AST from {ckpt_path}")
        else:
            self.config = ASTConfig()
            self.config.num_labels = 527  # 2 by default, audioset has 527 labels
            full_model = ASTForAudioClassification(self.config)
            logging.info("Initialized AST from scratch with the AST AudioSet config")

        # a `.pt` ckpt_path means the weights come from AVCLIP pre-training; they are loaded
        # at the end of __init__ after the submodules below have been created
        was_pt_on_avclip = ckpt_path is not None and ckpt_path.endswith(".pt")

        # feature extractor
        self.ast = full_model.audio_spectrogram_transformer

        if self.extract_features:
            # assign `feat_type` (use default if not specified)
            self.feat_type = "last_hidden_state" if feat_type is None else feat_type
            # define adapters if needed
            self.factorize_freq_time = factorize_freq_time
            # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
            transf_enc_layer_kwargs = dict(
                d_model=self.config.hidden_size,
                nhead=self.config.num_attention_heads,
                dim_feedforward=self.config.intermediate_size,
                activation=torch.nn.GELU(),
                batch_first=True,
                dropout=self.config.attention_probs_dropout_prob,
                layer_norm_eps=1e-6,
                norm_first=True,
            )
            if factorize_freq_time:
                self.feat_type = "last_hidden_state"  # this feat_type supports factorization
                # frequency aggreration
                if agg_freq_module == "TransformerEncoderLayer":
                    self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_freq_module == "AveragePooling":
                    self.freq_attn_agg = AveragePooling(
                        avg_pattern="BS D f t -> BS D t", then_permute_pattern="BS D t -> BS t D"
                    )
                # time aggreration
                # NOTE(review): if `agg_time_module` is None, the `"Identity" in agg_time_module`
                # check below raises TypeError — callers apparently always pass a string here.
                if agg_time_module == "TransformerEncoderLayer":
                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_time_module == "AveragePooling":
                    self.temp_attn_agg = AveragePooling(avg_pattern="BS t D -> BS D")
                elif "Identity" in agg_time_module:
                    self.temp_attn_agg = torch.nn.Identity()
            # define a global aggregation layer (aggregarate over segments)
            self.add_global_repr = add_global_repr
            if add_global_repr:
                if agg_segments_module == "TransformerEncoderLayer":
                    # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
                    # we need to add pos emb (PE) because previously we added the same PE for each segment
                    pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec//0.64sec + 1
                    self.global_attn_agg = TemporalTransformerEncoderLayer(
                        add_pos_emb=True,
                        pos_emb_drop=self.config.hidden_dropout_prob,
                        pos_max_len=pos_max_len,
                        **transf_enc_layer_kwargs,
                    )
                elif agg_segments_module == "AveragePooling":
                    self.global_attn_agg = AveragePooling(avg_pattern="B S D -> B D")
        else:
            self.classifier = full_model.classifier

        # AST.device fails with AttributeError. This is a workaround
        self.device = full_model.device

        # pre-trained on 12*101+2=1214 tokens, but we have less (e.g. 12*6+2=74)
        self.patch_position_emb()

        if was_pt_on_avclip:
            # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
            # and keep only the state_dict of the feat extractor
            check_if_file_exists_else_download(self.ckpt_path)
            ckpt = torch.load(ckpt_path, map_location="cpu")
            ckpt_weights = dict()
            for k, v in ckpt["state_dict"].items():
                if k.startswith(("module.a_encoder.", "a_encoder.")):
                    # strip the DataParallel/encoder prefixes so keys match this module
                    k = k.replace("module.", "").replace("a_encoder.", "")
                    ckpt_weights[k] = v
            _load_status = self.load_state_dict(ckpt_weights, strict=False)
            if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
                logging.warning(
                    f"Loading exact afeat_extractor ckpt from {self.ckpt_path} failed. \n"
                    f"Missing keys ({len(_load_status.missing_keys)}): "
                    f"{_load_status.missing_keys}, \n"
                    f"Unexpected keys ({len(_load_status.unexpected_keys)}): "
                    f"{_load_status.unexpected_keys} \n"
                    f"temp_attn_agg are expected to be missing if ckpt was pt contrastively."
                )
            else:
                logging.info(f"Loading afeat_extractor ckpt from {self.ckpt_path} succeeded.")

        # print the number of parameters
        logging.info(f"AST: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")

    def forward(
        self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None, **ast_kwargs
    ) -> torch.Tensor:
        """
        x: (B, S, T, F) where S is number of segments, F is number of (mel) frequency bins,
        ast_kwargs: additional arguments for the AST model
        cont_mask: (B, S, T, F) where 0s are the values to be masked out
        if `for_loop=True`, we use a for loop to extract features for each segment separately.
        if `for_loop=False`, we extract features for all segments at once.
        Using the for loop is slower but more memory efficient, while using all segments at once
        is faster but more memory inefficient.
        Using for loop allows to control the memory footprint by varying the number of videos in a
        batch (batch size) rather than the number of segments in a video.
        """
        B, S, T, F = x.shape

        if for_loop:
            assert cont_mask is None, "cont_mask is not supported with for_loop=True"
            orig_shape_s = (B, 1, T, F)
            # NOTE: since x is (B, S, T, F), and forward_segments expects (BS, T, F).
            # (B, S, T, F)[:, s] is (B, T, F) or (BS, T, F) if S=1.
            x = torch.cat(
                [self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)], dim=1
            )
        else:
            orig_shape = (B, S, T, F)
            # flatten segments into the batch dimension
            x = x.view(B * S, T, F)
            if cont_mask is not None:
                cont_mask = cont_mask.reshape(B * S, T, F)
            # AST expects a tensor of shape (B*S, T, F).
            x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs)
            # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
            x = x.view(B, S, *x.shape[1:])
        # x now is of shape (B, S, D) or (B, S, t, D) if `self.temp_attn_agg` is `Identity`

        global_x = None
        if self.extract_features and self.add_global_repr:  # lazy execution, throws AttributeError
            assert len(x.shape) == 3, f"Local representation should be (B, S, D) {x.shape}"
            global_x = self.global_attn_agg(x)  # (B, D)

        return x, global_x  # x is (B, S, ...), global_x is (B, D) or None

    def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs):
        """x is (BS, T, F), where S is the number of segments; cont_mask is (BS, T, F): 0s to be masked out"""
        # 'pooler_output': (B, D); or 'last_hidden_state: (B, T, D) where T is [CLS, DISTILL, <tokens>]
        # x_mask is (B, T) where 0s are the values to be masked out
        x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs)

        if self.extract_features:
            x = self.get_features_by_type(x)
            if self.factorize_freq_time:
                x = self.restore_freq_temp_dims(x, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
                if cont_mask is not None:
                    # duplicating the mask for the latent dimension (D) to be compatible with the next func
                    x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
                    x_mask = self.restore_freq_temp_dims(x_mask, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
                    # again removing the latent
                    x_mask = x_mask[:, 0, :, :]
                else:
                    x_mask = None
                x = self.freq_attn_agg(x, x_mask)  # (BS, t, D)
                x = self.temp_attn_agg(x)  # (BS, D) or (BS, t, D) if self.temp_attn_agg is Identity
        else:
            # classification path: pooled representation through the original AST head
            x = x["pooler_output"]
            x = self.classifier(x)
        return x

    def get_features_by_type(self, x: BaseModelOutputWithPooling) -> torch.Tensor:
        """Select the requested feature tensor from the AST output according to `self.feat_type`."""
        if self.feat_type == "pooler_output":
            return x["pooler_output"]  # (B, D)
        elif self.feat_type == "CLS":
            return x["last_hidden_state"][:, 0, :]  # (B, D)
        elif self.feat_type == "last_hidden_state":
            return x["last_hidden_state"]  # (B, 2+T, D)
        elif self.feat_type == "last_hidden_state_no_AUX":
            return x["last_hidden_state"][:, 2:, :]  # (B, T, D) removing CLS and distill tokens
        else:
            raise ValueError(f"Unknown feature type: {self.feat_type}")

    def restore_freq_temp_dims(self, feats, orig_shape: tuple):
        """
        feats are of shape (B*S, T, D)
        where T = 2 + f * t (if feat_type == 'last_hidden_state')
        where T = f * t (if feat_type == 'last_hidden_state_no_AUX')
        Our goal is to make them of shape (B*S, f, t, D) where f and t are dimensions after patching.
        From `self.ast.embeddings.patch_embeddings`, it follows that we could reshape feats:
        `feats.transpose(1, 2).view(B*S, D, f, t)`

        (Similar function is defined in for RGB features in `motionformer.py`)
        """
        B, S, T, F = orig_shape
        D = self.config.hidden_size

        # num patches in each dimension
        f, t = self.ast.embeddings.get_shape(self.config)

        if self.feat_type == "last_hidden_state":
            feats = feats[:, 2:, :]  # removing CLS and distill tokens

        feats = feats.permute(0, 2, 1)  # (B*S, D, T)
        feats = feats.view(B * S, D, f, t)  # (B*S, D, f, t)

        return feats

    def patch_position_emb(self):
        """Truncate the pre-trained positional embeddings to the (shorter) spectrogram length in use."""
        if self.max_spec_t is not None:
            self.config.max_length = self.max_spec_t
        f, t = self.ast.embeddings.get_shape(self.config)
        shortened = self.ast.embeddings.position_embeddings[:, : f * t + 2].clone()  # +2 for CLS and distill tokens
        self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened).to(self.device)

    def to(self, device):
        """AST.device fails with AttributeError. This is a workaround."""
        self.device = torch.device(device)
        return super().to(device)
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
    """Aggregates features along the frequency axis.

    Mirrors the spatio-temporal aggregation used by the visual feature extractor;
    see the definition of `BaseEncoderLayer` in `motionformer.py` for details.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        """x: (B*S, D, f, t); optional x_mask: (B*S, f, t), 0s mark positions to be masked out."""
        batch_seg, dim, n_freq, n_time = x.shape

        # Fold time into the batch axis so that attention runs over frequency only:
        # (B*S, D, f, t) -> (B*S, t, f, D) -> (B*S*t, f, D).
        # (.reshape is required: the permuted tensor is non-contiguous, so .view would fail.)
        x = x.permute(0, 3, 2, 1).reshape(batch_seg * n_time, n_freq, dim)

        if x_mask is not None:
            # Apply the same fold to the mask: (B*S, f, t) -> (B*S*t, f).
            x_mask = x_mask.permute(0, 2, 1).reshape(batch_seg * n_time, n_freq)

        # BaseEncoderLayer.forward prepends a CLS token and returns its representation: (B*S*t, D).
        x = super().forward(x=x, x_mask=x_mask)

        # Unfold time back out of the batch dimension.
        return x.view(batch_seg, n_time, dim)
|
HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/compute_desync_score.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import os
import subprocess
from pathlib import Path

import torch
import torchaudio
import torchvision
from omegaconf import OmegaConf

from . import data_transforms  # was a bare `import data_transforms`, which fails inside a package
from .data_transforms import make_class_grid, quantize_offset
from .synchformer import Synchformer
from .utils import check_if_file_exists_else_download, which_ffmpeg
| 15 |
+
|
| 16 |
+
def prepare_inputs(batch, device):
    """Move the audio and video tensors of a collated batch onto `device`."""
    audio = batch["audio"].to(device)
    video = batch["video"].to(device)
    return audio, video
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_test_transforms():
    """Build the deterministic test-time transform pipeline (video transforms, then audio)."""
    # Segmentation parameters shared by the temporal crop and the segment generator.
    seg_vframes = 16
    n_segments = 14
    seg_step = 0.5

    video_ts = [
        data_transforms.EqualifyFromRight(),
        data_transforms.RGBSpatialCrop(input_size=224, is_random=False),
        data_transforms.TemporalCropAndOffset(
            crop_len_sec=5,
            max_off_sec=2,  # https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
            max_wiggle_sec=0.0,
            do_offset=True,
            offset_type="grid",
            # NOTE(review): "null" looks like a literal carried over from a YAML config
            # and presumably should be None — confirm against the training config.
            prob_oos="null",
            grid_size=21,
            segment_size_vframes=seg_vframes,
            n_segments=n_segments,
            step_size_seg=seg_step,
            vfps=25,
        ),
        data_transforms.GenerateMultipleSegments(
            segment_size_vframes=seg_vframes,
            n_segments=n_segments,
            is_start_random=False,
            step_size_seg=seg_step,
        ),
        data_transforms.RGBToHalfToZeroOne(),
        data_transforms.RGBNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # motionformer normalization
    ]

    audio_ts = [
        data_transforms.AudioMelSpectrogram(
            sample_rate=16000,
            win_length=400,  # 25 ms * 16 kHz
            hop_length=160,  # 10 ms * 16 kHz
            n_fft=1024,  # 2^(ceil(log2(window_size * sampling_rate)))
            n_mels=128,  # as in AST
        ),
        data_transforms.AudioLog(),
        data_transforms.PadOrTruncate(max_spec_t=66),
        data_transforms.AudioNormalizeAST(mean=-4.2677393, std=4.5689974),  # AST, pre-trained on AudioSet
        data_transforms.PermuteStreams(
            einops_order_audio="S F T -> S 1 F T", einops_order_rgb="S T C H W -> S T C H W"  # rgb left as-is
        ),
    ]

    return torchvision.transforms.Compose(video_ts + audio_ts)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None):
    """Load RGB frames and a mono audio track from a video file.

    Returns:
        rgb: (Tv, 3, H, W) uint8 frames in [0, 255]
        audio: (Ta,) waveform, averaged over channels
        meta: legacy-format dict with the video fps and audio framerate
    """
    orig_path = path
    # Decode both streams: rgb is (Tv, 3, H, W), audio is (Ca, Ta).
    rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, "sec", output_format="TCHW")
    assert meta["video_fps"], f"No video fps for {orig_path}"
    # Downmix to mono: (Ca, Ta) -> (Ta,).
    audio = audio.mean(dim=0)
    # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader.
    legacy_meta = {
        "video": {"fps": [meta["video_fps"]]},
        "audio": {"framerate": [meta["audio_fps"]]},
    }
    return rgb, audio, legacy_meta
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def reencode_video(path, vfps=25, afps=16000, in_size=256):
    """Re-encode a video to the given frame rate, audio rate, and minimum side size.

    Writes `<stem>_{vfps}fps_{in_size}side_{afps}hz.mp4` (plus a mono 16-bit PCM
    `.wav` with the same stem) under `<cwd>/vis/` and returns the new mp4 path.

    Fix vs. original: commands are passed to subprocess as argument lists instead
    of `cmd_string.split()`, so input paths containing spaces no longer break.
    """
    ffmpeg = which_ffmpeg()
    assert ffmpeg != "", "Is ffmpeg installed? Check if the conda environment is activated."
    new_path = Path.cwd() / "vis" / f"{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4"
    new_path.parent.mkdir(exist_ok=True)
    new_path = str(new_path)
    # 1) change fps, 2) resize so min(H, W) == in_size (vertical vids are supported),
    # 3) crop to even dimensions, 4) change the audio framerate.
    # The single quotes are passed literally to ffmpeg (no shell involved), matching
    # the original behavior of splitting the command string.
    vf_filter = (
        f"fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',"
        "crop='trunc(iw/2)'*2:'trunc(ih/2)'*2"
    )
    video_cmd = [
        ffmpeg,
        "-hide_banner", "-loglevel", "panic",  # no info/error printing
        "-y", "-i", str(path),
        "-vf", vf_filter,
        "-ar", str(afps),
        new_path,
    ]
    subprocess.call(video_cmd)
    # Extract the audio track as mono 16-bit PCM next to the re-encoded video.
    audio_cmd = [
        ffmpeg,
        "-hide_banner", "-loglevel", "panic",
        "-y", "-i", new_path,
        "-acodec", "pcm_s16le", "-ac", "1",
        new_path.replace(".mp4", ".wav"),
    ]
    subprocess.call(audio_cmd)
    return new_path
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def decode_single_video_prediction(off_logits, grid, item):
    """Print the ground-truth offset and top-k predicted offsets for one video.

    Returns the per-class softmax probabilities (batch dimension removed).
    """
    label = item["targets"]["offset_label"].item()
    print("Ground Truth offset (sec):", f"{label:.2f} ({quantize_offset(grid, label)[-1].item()})")
    print()
    print("Prediction Results:")

    off_probs = torch.softmax(off_logits, dim=-1)
    num_shown = min(off_probs.shape[-1], 5)
    topk_logits, topk_preds = torch.topk(off_logits, num_shown)

    # Single-video inference only: strip the batch dimension everywhere.
    assert len(topk_logits) == 1, "batch is larger than 1"
    topk_logits, topk_preds = topk_logits[0], topk_preds[0]
    off_logits, off_probs = off_logits[0], off_probs[0]

    for cls_hat in topk_preds:
        print(f'p={off_probs[cls_hat]:.4f} ({off_logits[cls_hat]:.4f}), "{grid[cls_hat]:.2f}" ({cls_hat})')
    return off_probs
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def main(args):
    """Predict the audio-visual offset (desync) for a single video.

    Re-encodes the input to the expected fps/audio-rate/size if needed, applies the
    test-time transforms, and runs one forward pass through the sync model.

    NOTE(review): reads the module-level `synchformer` model created in the
    `__main__` guard; `args.exp_name` is parsed by the CLI but not used here.
    """
    # Rates/size the model was trained with.
    vfps = 25
    afps = 16000
    in_size = 256
    # making the offset class grid similar to the one used in transforms,
    # refer to the used one: https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
    max_off_sec = 2
    num_cls = 21

    # checking if the provided video has the correct frame rates
    print(f"Using video: {args.vid_path}")
    v, _, info = torchvision.io.read_video(args.vid_path, pts_unit="sec")
    _, H, W, _ = v.shape
    if info["video_fps"] != vfps or info["audio_fps"] != afps or min(H, W) != in_size:
        print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=" ")
        print(f'afps: {info["audio_fps"]} -> {afps};', end=" ")
        print(f"{(H, W)} -> min(H, W)={in_size}")
        args.vid_path = reencode_video(args.vid_path, vfps, afps, in_size)
    else:
        print(f'Skipping reencoding. vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}')

    device = torch.device(args.device)

    # load visual and audio streams
    # rgb: (Tv, 3, H, W) in [0, 225], audio: (Ta,) in [-1, 1]
    rgb, audio, meta = get_video_and_audio(args.vid_path, get_meta=True)

    # making an item (dict) to apply transformations
    # NOTE: here is how it works:
    # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3`
    # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio
    # track by `args.offset_sec` seconds. It means that if `offset_sec` > 0, the audio will
    # start by `offset_sec` earlier than the rgb track.
    # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`)
    item = dict(
        video=rgb,
        audio=audio,
        meta=meta,
        path=args.vid_path,
        split="test",
        targets={
            "v_start_i_sec": args.v_start_i_sec,
            "offset_sec": args.offset_sec,
        },
    )

    grid = make_class_grid(-max_off_sec, max_off_sec, num_cls)
    if not (min(grid) <= item["targets"]["offset_sec"] <= max(grid)):
        print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}')

    # applying the test-time transform
    item = get_test_transforms()(item)

    # prepare inputs for inference
    batch = torch.utils.data.default_collate([item])
    aud, vid = prepare_inputs(batch, device)

    # TODO:
    # sanity check: we will take the input to the `model` and recontruct make a video from it.
    # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified)
    # reconstruct_video_from_input(aud, vid, batch['meta'], args.vid_path, args.v_start_i_sec, args.offset_sec,
    #                              vfps, afps)

    # forward pass (inference only, mixed precision on CUDA)
    with torch.set_grad_enabled(False):
        with torch.autocast("cuda", enabled=True):
            _, logits = synchformer(vid, aud)

    # simply prints the results of the prediction
    decode_single_video_prediction(logits, grid, item)
| 194 |
+
|
| 195 |
+
|
| 196 |
+
if __name__ == "__main__":
    # Fix vs. original: `os` was used below but never imported, which raised a
    # NameError at runtime. Imported locally so this block is self-contained.
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_name", required=True, help="In a format: xx-xx-xxTxx-xx-xx")
    parser.add_argument("--vid_path", required=True, help="A path to .mp4 video")
    parser.add_argument("--offset_sec", type=float, default=0.0)
    parser.add_argument("--v_start_i_sec", type=float, default=0.0)
    parser.add_argument("--device", default="cuda:0")
    args = parser.parse_args()

    # Module-level global: `main` reads it via module scope.
    synchformer = Synchformer().cuda().eval()
    synchformer.load_state_dict(
        torch.load(
            # Checkpoint path is overridable via the SYNCHFORMER_WEIGHTS env var.
            os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"),
            weights_only=True,
            map_location="cpu",
        )
    )

    main(args)
|