Jack Wu committed on
Commit
c60109f
·
1 Parent(s): 6cf4573
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .idea/.gitignore +10 -0
  2. .idea/Generate_Audio_for_Video.iml +14 -0
  3. .idea/inspectionProfiles/Project_Default.xml +7 -0
  4. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  5. .idea/modules.xml +8 -0
  6. .idea/vcs.xml +6 -0
  7. HunyuanVideo-Foley/.gitattributes +3 -0
  8. HunyuanVideo-Foley/.gitignore +159 -0
  9. HunyuanVideo-Foley/.pre-commit-config.yaml +38 -0
  10. HunyuanVideo-Foley/DEVELOPMENT.md +187 -0
  11. HunyuanVideo-Foley/INSTALL.md +203 -0
  12. HunyuanVideo-Foley/LICENSE +77 -0
  13. HunyuanVideo-Foley/MANIFEST.in +38 -0
  14. HunyuanVideo-Foley/NOTICE +27 -0
  15. HunyuanVideo-Foley/README.md +519 -0
  16. HunyuanVideo-Foley/build_package.sh +58 -0
  17. HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml +48 -0
  18. HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml +48 -0
  19. HunyuanVideo-Foley/download_test_videos.sh +11 -0
  20. HunyuanVideo-Foley/gradio_app.py +834 -0
  21. HunyuanVideo-Foley/hunyuanvideo_foley/__init__.py +30 -0
  22. HunyuanVideo-Foley/hunyuanvideo_foley/cli.py +141 -0
  23. HunyuanVideo-Foley/hunyuanvideo_foley/constants.py +57 -0
  24. HunyuanVideo-Foley/hunyuanvideo_foley/models/__init__.py +0 -0
  25. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/__init__.py +16 -0
  26. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/__main__.py +36 -0
  27. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/__init__.py +4 -0
  28. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/base.py +301 -0
  29. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/dac.py +410 -0
  30. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/discriminator.py +228 -0
  31. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/__init__.py +3 -0
  32. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/layers.py +33 -0
  33. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/loss.py +368 -0
  34. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/quantize.py +262 -0
  35. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py +91 -0
  36. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/__init__.py +121 -0
  37. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/decode.py +95 -0
  38. HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/encode.py +94 -0
  39. HunyuanVideo-Foley/hunyuanvideo_foley/models/hifi_foley.py +794 -0
  40. HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/__init__.py +0 -0
  41. HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/activation_layers.py +44 -0
  42. HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/attn_layers.py +546 -0
  43. HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/embed_layers.py +136 -0
  44. HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/mlp_layers.py +149 -0
  45. HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/modulate_layers.py +49 -0
  46. HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/norm_layers.py +70 -0
  47. HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/posemb_layers.py +159 -0
  48. HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/__init__.py +1 -0
  49. HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/ast_model.py +289 -0
  50. HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/compute_desync_score.py +214 -0
.idea/.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Ignored default folder with query files
5
+ /queries/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
9
+ # Editor-based HTTP Client requests
10
+ /httpRequests/
.idea/Generate_Audio_for_Video.iml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$">
5
+ <excludeFolder url="file://$MODULE_DIR$/.venv" />
6
+ </content>
7
+ <orderEntry type="inheritedJdk" />
8
+ <orderEntry type="sourceFolder" forTests="false" />
9
+ </component>
10
+ <component name="PyDocumentationSettings">
11
+ <option name="format" value="PLAIN" />
12
+ <option name="myDocStringFormat" value="Plain" />
13
+ </component>
14
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="SqlNoDataSourceInspection" enabled="false" level="WARNING" enabled_by_default="false" />
5
+ <inspection_tool class="TodoComment" enabled="false" level="INFORMATION" enabled_by_default="false" />
6
+ </profile>
7
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/Generate_Audio_for_Video.iml" filepath="$PROJECT_DIR$/.idea/Generate_Audio_for_Video.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
HunyuanVideo-Foley/.gitattributes ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ assets/data_pipeline.png filter=lfs diff=lfs merge=lfs -text
2
+ assets/model_arch.png filter=lfs diff=lfs merge=lfs -text
3
+ *.png filter=lfs diff=lfs merge=lfs -text
HunyuanVideo-Foley/.gitignore ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # ==========================================
132
+ # Custom settings
133
+ # ==========================================
134
+
135
+ # For MacOS
136
+ .DS_Store
137
+
138
+ # For IDEs
139
+ .idea/
140
+ .vscode/
141
+ pyrightconfig.json
142
+ .cursorignore
143
+
144
+ assets/
145
+ examples/
146
+
147
+ # For global settings
148
+ __*/
149
+ **/my_*
150
+ tmp*.*
151
+ .my*
152
+ # Model checkpoints
153
+ *.pt
154
+ *.ckpt
155
+ *.pth
156
+ *.safetensors
157
+
158
+
159
+ CLAUDE.md
HunyuanVideo-Foley/.pre-commit-config.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.4.0
4
+ hooks:
5
+ - id: trailing-whitespace
6
+ - id: end-of-file-fixer
7
+ - id: check-yaml
8
+ - id: check-added-large-files
9
+ - id: check-merge-conflict
10
+ - id: debug-statements
11
+ - id: check-docstring-first
12
+
13
+ - repo: https://github.com/psf/black
14
+ rev: 23.3.0
15
+ hooks:
16
+ - id: black
17
+ language_version: python3
18
+ args: [--line-length=120]
19
+
20
+ - repo: https://github.com/pycqa/isort
21
+ rev: 5.12.0
22
+ hooks:
23
+ - id: isort
24
+ args: [--profile, black, --line-length=120]
25
+
26
+ - repo: https://github.com/pycqa/flake8
27
+ rev: 6.0.0
28
+ hooks:
29
+ - id: flake8
30
+ args: [--max-line-length=120]
31
+ additional_dependencies: [flake8-docstrings]
32
+
33
+ - repo: https://github.com/pre-commit/mirrors-mypy
34
+ rev: v1.3.0
35
+ hooks:
36
+ - id: mypy
37
+ additional_dependencies: [types-all]
38
+ args: [--ignore-missing-imports]
HunyuanVideo-Foley/DEVELOPMENT.md ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Development Guide
2
+
3
+ This document provides guidelines for developing and contributing to the HunyuanVideo-Foley project.
4
+
5
+ ## Code Style and Quality
6
+
7
+ ### Code Formatting
8
+
9
+ We use the following tools to maintain consistent code style:
10
+
11
+ - **Black**: Code formatter with 120 character line length
12
+ - **isort**: Import sorter compatible with Black
13
+ - **flake8**: Linting and style checking
14
+ - **mypy**: Static type checking
15
+
16
+ ### Pre-commit Hooks
17
+
18
+ Install pre-commit hooks to automatically format code before commits:
19
+
20
+ ```bash
21
+ pip install pre-commit
22
+ pre-commit install
23
+ ```
24
+
25
+ ### Manual Code Formatting
26
+
27
+ Format code manually:
28
+
29
+ ```bash
30
+ # Format all Python files
31
+ black --line-length 120 .
32
+
33
+ # Sort imports
34
+ isort --profile black --line-length 120 .
35
+
36
+ # Check code style
37
+ flake8 --max-line-length 120
38
+
39
+ # Type checking
40
+ mypy --ignore-missing-imports .
41
+ ```
42
+
43
+ ## Project Structure
44
+
45
+ ```
46
+ hunyuanvideo_foley/
47
+ ├── models/ # Model implementations
48
+ │ ├── hifi_foley.py # Main model
49
+ │ ├── nn/ # Neural network layers
50
+ │ ├── dac_vae/ # Audio VAE
51
+ │ └── synchformer/ # Synchronization model
52
+ ├── utils/ # Utilities
53
+ │ ├── config_utils.py # Configuration handling
54
+ │ ├── feature_utils.py # Feature extraction
55
+ │ ├── model_utils.py # Model loading/saving
56
+ │ └── media_utils.py # Audio/video processing
57
+ └── constants.py # Project constants
58
+ ```
59
+
60
+ ## Coding Standards
61
+
62
+ ### Error Handling
63
+
64
+ - Use custom exceptions for domain-specific errors
65
+ - Always validate inputs at function boundaries
66
+ - Log errors with appropriate levels (ERROR, WARNING, INFO)
67
+ - Provide helpful error messages to users
68
+
69
+ ### Type Hints
70
+
71
+ - Add type hints to all function parameters and return values
72
+ - Use `Optional[Type]` for nullable parameters
73
+ - Import types from `typing` module
74
+
75
+ ### Documentation
76
+
77
+ - Add docstrings to all public functions and classes
78
+ - Use Google-style docstrings
79
+ - Document parameters, return values, and exceptions
80
+
81
+ ### Example Function
82
+
83
+ ```python
84
+ def process_video(
85
+ video_path: str,
86
+ max_duration: Optional[float] = None
87
+ ) -> Tuple[np.ndarray, float]:
88
+ """
89
+ Process video file and extract frames.
90
+
91
+ Args:
92
+ video_path: Path to input video file
93
+ max_duration: Maximum duration in seconds (optional)
94
+
95
+ Returns:
96
+ Tuple of (frames array, duration in seconds)
97
+
98
+ Raises:
99
+ FileNotFoundError: If video file doesn't exist
100
+ VideoProcessingError: If video processing fails
101
+ """
102
+ if not os.path.exists(video_path):
103
+ raise FileNotFoundError(f"Video file not found: {video_path}")
104
+
105
+ # Implementation here...
106
+ ```
107
+
108
+ ## Testing
109
+
110
+ ### Running Tests
111
+
112
+ ```bash
113
+ # Run all tests
114
+ python -m pytest
115
+
116
+ # Run specific test file
117
+ python -m pytest tests/test_feature_utils.py
118
+
119
+ # Run with coverage
120
+ python -m pytest --cov=hunyuanvideo_foley
121
+ ```
122
+
123
+ ### Writing Tests
124
+
125
+ - Place tests in `tests/` directory
126
+ - Name test files as `test_*.py`
127
+ - Use descriptive test function names
128
+ - Test edge cases and error conditions
129
+
130
+ ## Development Workflow
131
+
132
+ 1. **Setup Environment**
133
+ ```bash
134
+ python -m venv venv
135
+ source venv/bin/activate # Linux/Mac
136
+ # or
137
+ venv\Scripts\activate # Windows
138
+
139
+ pip install -r requirements.txt
140
+ pip install -e .
141
+ ```
142
+
143
+ 2. **Install Development Tools**
144
+ ```bash
145
+ pre-commit install
146
+ ```
147
+
148
+ 3. **Make Changes**
149
+ - Follow the coding standards above
150
+ - Add tests for new functionality
151
+ - Update documentation as needed
152
+
153
+ 4. **Run Quality Checks**
154
+ ```bash
155
+ black --check --line-length 120 .
156
+ isort --check-only --profile black .
157
+ flake8 --max-line-length 120
158
+ mypy --ignore-missing-imports .
159
+ pytest
160
+ ```
161
+
162
+ 5. **Commit Changes**
163
+ ```bash
164
+ git add .
165
+ git commit -m "feat: add new feature"
166
+ ```
167
+
168
+ ## Performance Considerations
169
+
170
+ - Use `torch.no_grad()` for inference-only code
171
+ - Leverage GPU when available
172
+ - Implement batch processing where possible
173
+ - Profile code to identify bottlenecks
174
+
175
+ ## Dependencies
176
+
177
+ - Keep dependencies minimal and well-maintained
178
+ - Pin versions for reproducibility
179
+ - Separate development dependencies from runtime dependencies
180
+ - Document any special installation requirements
181
+
182
+ ## Configuration
183
+
184
+ - Use centralized configuration in `constants.py`
185
+ - Support environment variable overrides
186
+ - Provide sensible defaults for all parameters
187
+ - Validate configuration at startup
HunyuanVideo-Foley/INSTALL.md ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 安装指南 - HunyuanVideo-Foley
2
+
3
+ 本文档提供了将 HunyuanVideo-Foley 作为 Python 包安装和使用的详细指南。
4
+
5
+ ## 安装方式
6
+
7
+ ### 方式1:从源码安装(推荐)
8
+
9
+ ```bash
10
+ # 克隆仓库
11
+ git clone https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley
12
+ cd HunyuanVideo-Foley
13
+
14
+ # 安装包(开发模式)
15
+ pip install -e .
16
+
17
+ # 或安装包含所有可选依赖
18
+ pip install -e .[all]
19
+ ```
20
+
21
+ ### 方式2:直接从GitHub安装
22
+
23
+ ```bash
24
+ pip install git+https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley.git
25
+ ```
26
+
27
+ ### 方式3:构建wheel包安装
28
+
29
+ ```bash
30
+ # 在项目根目录下
31
+ python setup.py bdist_wheel
32
+ pip install dist/hunyuanvideo_foley-1.0.0-py3-none-any.whl
33
+ ```
34
+
35
+ ## 特殊依赖安装
36
+
37
+ 由于某些依赖不在PyPI上,需要单独安装:
38
+
39
+ ```bash
40
+ # 安装audiotools(必需)
41
+ pip install git+https://github.com/descriptinc/audiotools
42
+
43
+ # 安装特定版本的transformers(支持SigLIP2)
44
+ pip install git+https://github.com/huggingface/transformers@v4.49.0-SigLIP-2
45
+ ```
46
+
47
+ ## 可选依赖安装
48
+
49
+ ```bash
50
+ # 安装开发依赖
51
+ pip install hunyuanvideo-foley[dev]
52
+
53
+ # 安装测试依赖
54
+ pip install hunyuanvideo-foley[test]
55
+
56
+ # 安装Gradio界面依赖
57
+ pip install hunyuanvideo-foley[gradio]
58
+
59
+ # 安装所有可选依赖
60
+ pip install hunyuanvideo-foley[all]
61
+ ```
62
+
63
+ ## 验证安装
64
+
65
+ ```bash
66
+ # 检查包是否正确安装
67
+ python -c "import hunyuanvideo_foley; print(hunyuanvideo_foley.__version__)"
68
+
69
+ # 检查命令行工具
70
+ hunyuanvideo-foley --help
71
+ ```
72
+
73
+ ## 使用方法
74
+
75
+ ### 1. 作为Python包使用
76
+
77
+ ```python
78
+ import hunyuanvideo_foley as hvf
79
+
80
+ # 加载模型
81
+ model_dict, cfg = hvf.load_model(
82
+ model_path="path/to/model",
83
+ config_path="configs/hunyuanvideo-foley-xxl.yaml"
84
+ )
85
+
86
+ # 处理特征
87
+ visual_feats, text_feats, audio_len = hvf.feature_process(
88
+ video_path="video.mp4",
89
+ prompt="footsteps on gravel",
90
+ model_dict=model_dict,
91
+ cfg=cfg
92
+ )
93
+
94
+ # 生成音频
95
+ audio, sample_rate = hvf.denoise_process(
96
+ visual_feats, text_feats, audio_len,
97
+ model_dict, cfg
98
+ )
99
+ ```
100
+
101
+ ### 2. 使用命令行工具
102
+
103
+ ```bash
104
+ # 单个视频处理
105
+ hunyuanvideo-foley \
106
+ --model_path ./pretrained_models \
107
+ --single_video video.mp4 \
108
+ --single_prompt "footsteps on gravel" \
109
+ --output_dir ./outputs
110
+
111
+ # 批量处理
112
+ hunyuanvideo-foley \
113
+ --model_path ./pretrained_models \
114
+ --csv_path batch_videos.csv \
115
+ --output_dir ./outputs
116
+
117
+ # 启动Gradio界面
118
+ hunyuanvideo-foley --gradio --model_path ./pretrained_models
119
+ ```
120
+
121
+ ### 3. 使用原始脚本(向后兼容)
122
+
123
+ ```bash
124
+ # 使用原始infer.py脚本
125
+ python infer.py --model_path ./pretrained_models --single_video video.mp4 --single_prompt "audio description"
126
+
127
+ # 启动Gradio应用
128
+ export HIFI_FOLEY_MODEL_PATH=./pretrained_models
129
+ python gradio_app.py
130
+ ```
131
+
132
+ ## 开发环境设置
133
+
134
+ 如果你想参与开发:
135
+
136
+ ```bash
137
+ # 克隆项目
138
+ git clone https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley
139
+ cd HunyuanVideo-Foley
140
+
141
+ # 安装开发版本
142
+ pip install -e .[dev]
143
+
144
+ # 安装pre-commit钩子
145
+ pre-commit install
146
+
147
+ # 运行测试
148
+ python -m pytest
149
+
150
+ # 代码格式化
151
+ black --line-length 120 .
152
+ isort --profile black .
153
+
154
+ # 类型检查
155
+ mypy --ignore-missing-imports .
156
+ ```
157
+
158
+ ## 系统要求
159
+
160
+ - **Python**: 3.8+
161
+ - **操作系统**: Linux(主要支持),macOS,Windows
162
+ - **GPU内存**: 推荐 ≥24GB VRAM(如RTX 3090/4090)
163
+ - **CUDA版本**: 12.4 或 11.8(推荐)
164
+
165
+ ## 故障排除
166
+
167
+ ### 常见问题
168
+
169
+ 1. **ImportError: No module named 'audiotools'**
170
+ ```bash
171
+ pip install git+https://github.com/descriptinc/audiotools
172
+ ```
173
+
174
+ 2. **CUDA内存不足**
175
+ - 使用较小的批次大小
176
+ - 确保GPU有足够的VRAM(推荐24GB+)
177
+
178
+ 3. **transformers版本问题**
179
+ ```bash
180
+ pip install git+https://github.com/huggingface/transformers@v4.49.0-SigLIP-2
181
+ ```
182
+
183
+ ### 获取帮助
184
+
185
+ - 查看项目README: [GitHub](https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley)
186
+ - 报告问题: [GitHub Issues](https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley/issues)
187
+ - 论文: [arXiv:2508.16930](https://arxiv.org/abs/2508.16930)
188
+
189
+ ## 模型下载
190
+
191
+ ```bash
192
+ # 使用HuggingFace Hub
193
+ git clone https://huggingface.co/tencent/HunyuanVideo-Foley
194
+
195
+ # 或使用huggingface-cli
196
+ huggingface-cli download tencent/HunyuanVideo-Foley
197
+ ```
198
+
199
+ ## 配置文件
200
+
201
+ 包安装后,配置文件位于:
202
+ - `hunyuanvideo_foley/configs/` 目录
203
+ - 默认配置:`configs/hunyuanvideo-foley-xxl.yaml`
HunyuanVideo-Foley/LICENSE ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
2
+ Tencent HunyuanVideo-Foley Release Date: August 28, 2025
3
+ THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
4
+ By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
5
+ 1. DEFINITIONS.
6
+ a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
7
+ b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
8
+ c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
9
+ d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
10
+ e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
11
+ f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
12
+ g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
13
+ h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
14
+ i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
15
+ j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo-Foley released at [https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley].
16
+ k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
17
+ l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
18
+ m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
19
+ n. “including” shall mean including but not limited to.
20
+ 2. GRANT OF RIGHTS.
21
+ We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
22
+ 3. DISTRIBUTION.
23
+ You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
24
+ a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
25
+ b. You must cause any modified files to carry prominent notices stating that You changed the files;
26
+ c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
27
+ d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
28
+ You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
29
+ 4. ADDITIONAL COMMERCIAL TERMS.
30
+ If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
31
+ 5. RULES OF USE.
32
+ a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
33
+ b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
34
+ c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
35
+ 6. INTELLECTUAL PROPERTY.
36
+ a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
37
+ b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
38
+ c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
39
+ d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
40
+ 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
41
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
42
+ b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
43
+ c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
44
+ 8. SURVIVAL AND TERMINATION.
45
+ a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
46
+ b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
47
+ 9. GOVERNING LAW AND JURISDICTION.
48
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
49
+ b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
50
+
51
+ EXHIBIT A
52
+ ACCEPTABLE USE POLICY
53
+
54
+ Tencent reserves the right to update this Acceptable Use Policy from time to time.
55
+ Last modified: November 5, 2024
56
+
57
+ Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
58
+ 1. Outside the Territory;
59
+ 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
60
+ 3. To harm Yourself or others;
61
+ 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
62
+ 5. To override or circumvent the safety guardrails and safeguards We have put in place;
63
+ 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
64
+ 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
65
+ 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
66
+ 9. To intentionally defame, disparage or otherwise harass others;
67
+ 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
68
+ 11. To generate or disseminate personal identifiable information with the purpose of harming others;
69
+ 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including through the use of bot-generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
70
+ 13. To impersonate another individual without consent, authorization, or legal right;
71
+ 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
72
+ 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
73
+ 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
74
+ 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
75
+ 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
76
+ 19. For military purposes;
77
+ 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
HunyuanVideo-Foley/MANIFEST.in ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Include package metadata and documentation
2
+ include README.md
3
+ include LICENSE
4
+ include NOTICE
5
+ include DEVELOPMENT.md
6
+ include CLAUDE.md
7
+ include requirements.txt
8
+ include pyproject.toml
9
+ include pytest.ini
10
+
11
+ # Include configuration files
12
+ include configs/*.yaml
13
+ include configs/*.yml
14
+ recursive-include hunyuanvideo_foley/configs *.yaml *.yml
15
+
16
+ # Include test assets if any
17
+ include assets/*.csv
18
+ include assets/*.txt
19
+ recursive-include assets/test_videos *
20
+
21
+ # Include example scripts
22
+ include *.py
23
+ include *.sh
24
+
25
+ # Include test files
26
+ recursive-include tests *.py
27
+
28
+ # Exclude unnecessary files
29
+ global-exclude *.pyc
30
+ global-exclude *.pyo
31
+ global-exclude *~
32
+ global-exclude .DS_Store
33
+ global-exclude __pycache__
34
+ prune .git
35
+ prune .github
36
+ prune examples/*/outputs
37
+ prune **/__pycache__
38
+ prune **/*.pyc
HunyuanVideo-Foley/NOTICE ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Usage and Legal Notices:
2
+
3
+ Tencent is pleased to support the open source community by making Tencent HunyuanVideo-Foley available.
4
+
5
+ Copyright (C) 2025 Tencent. All rights reserved.
6
+
7
+ Tencent HunyuanVideo-Foley is licensed under TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT, which can be found in this repository called "LICENSE", except for the third-party components listed below. Tencent HunyuanVideo-Foley does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
8
+
9
+ For avoidance of doubts, Tencent HunyuanVideo-Foley means the large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
10
+
11
+
12
+ Other dependencies and licenses:
13
+
14
+
15
+ Open Source Software Licensed under the MIT License:
16
+ --------------------------------------------------------------------
17
+ 1. syncformer
18
+ Copyright (c) 2024 Vladimir Iashin
19
+
20
+
21
+ Terms of the MIT License:
22
+ --------------------------------------------------------------------
23
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
24
+
25
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
26
+
27
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
HunyuanVideo-Foley/README.md ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley
4
+
5
+ <img src="assets/logo.png" alt="HunyuanVideo-Foley Logo" width="400">
6
+
7
+ <h4>Multimodal Diffusion with Representation Alignment for High-Fidelity Foley Audio Generation</h4>
8
+
9
+ <p align="center">
10
+ <strong>Professional-grade AI sound effect generation for video content creators</strong>
11
+ </p>
12
+
13
+ <div align="center">
14
+ <a href=https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley target="_blank"><img src=https://img.shields.io/badge/Code-black.svg?logo=github height=22px></a>
15
+ <a href=https://szczesnys.github.io/hunyuanvideo-foley target="_blank"><img src=https://img.shields.io/badge/Page-bb8a2e.svg?logo=github height=22px></a>
16
+ <a href=https://huggingface.co/tencent/HunyuanVideo-Foley target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20Models-d96902.svg height=22px></a>
17
+ <a href=https://huggingface.co/spaces/tencent/HunyuanVideo-Foley target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20Demo-276cb4.svg height=22px></a>
18
+ <a href=https://arxiv.org/abs/2508.16930 target="_blank"><img src=https://img.shields.io/badge/Report-b5212f.svg?logo=arxiv height=22px></a>
19
+ <a href=https://x.com/TencentHunyuan target="_blank"><img src=https://img.shields.io/badge/Hunyuan-black.svg?logo=x height=22px></a>
20
+ <a href=https://discord.gg/YEyGGn6Bte target="_blank"><img src=https://img.shields.io/badge/Hunyuan-141984.svg?logo=discord height=22px></a>
21
+ </div>
22
+
23
+ </div>
24
+
25
+ ---
26
+
27
+ <div align="center">
28
+
29
+ ### 👥 **Authors**
30
+
31
+ <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 15px; margin: 20px 0;">
32
+
33
+ **Sizhe Shan**<sup>1,2*</sup> • **Qiulin Li**<sup>1,3*</sup> • **Yutao Cui**<sup>1</sup> • **Miles Yang**<sup>1</sup> • **Yuehai Wang**<sup>2</sup> • **Qun Yang**<sup>3</sup> • **Jin Zhou**<sup>1†</sup> • **Zhao Zhong**<sup>1</sup>
34
+
35
+ </div>
36
+
37
+ <div style="margin-top: 15px; font-size: 14px; color: #666;">
38
+
39
+ 🏢 <sup>1</sup>**Tencent Hunyuan** • 🎓 <sup>2</sup>**Zhejiang University** • ✈️ <sup>3</sup>**Nanjing University of Aeronautics and Astronautics**
40
+
41
+ *Equal contribution • †Project lead
42
+
43
+ </div>
44
+
45
+ </div>
46
+
47
+
48
+ ---
49
+
50
+ ## 🔥🔥🔥 **News**
51
+
52
+ <div style="background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%); padding: 20px; border-radius: 15px; margin: 20px 0; border-left: 5px solid #2196f3;">
53
+
54
+ - **[2025.9.29]** 🚀 **HunyuanVideo-Foley-XL Model Release** - Release XL-sized model with offload inference support, significantly reducing VRAM requirements.
55
+ - **[2025.8.28]** 🌟 **HunyuanVideo-Foley Open Source Release** - Inference code and model weights publicly available.
56
+
57
+ </div>
58
+
59
+ ---
60
+
61
+ ## 🎥 **Demo & Showcase**
62
+
63
+ <div align="center">
64
+
65
+ > **Experience the magic of AI-generated Foley audio in perfect sync with video content!**
66
+
67
+ <div style="border: 3px solid #4A90E2; border-radius: 15px; padding: 10px; margin: 20px 0; background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);">
68
+
69
+ <video src="https://github.com/user-attachments/assets/d6e1b6fd-6980-4a68-8717-74298d064195" width="80%" controls style="border-radius: 10px; box-shadow: 0 8px 32px rgba(0,0,0,0.1);"> </video>
70
+
71
+ <p><em>🎬 Watch how HunyuanVideo-Foley generates immersive sound effects synchronized with video content</em></p>
72
+
73
+ </div>
74
+
75
+ ---
76
+
77
+ ## 🤝 **Community Contributions**
78
+
79
+ <div style="background: #f8f9fa; padding: 20px; border-radius: 10px; border-left: 4px solid #28a745; margin: 20px 0; color: #333;">
80
+
81
+ **ComfyUI Integration** - Thanks to the amazing community for creating ComfyUI nodes:
82
+
83
+ - **[if-ai/ComfyUI_HunyuanVideoFoley](https://github.com/if-ai/ComfyUI_HunyuanVideoFoley)** - ComfyUI workflow integration, which supports CPU offloading and FP8 quantization
84
+ - **[phazei/ComfyUI-HunyuanVideo-Foley](https://github.com/phazei/ComfyUI-HunyuanVideo-Foley)** - Alternative ComfyUI node implementation, which supports different precision modes
85
+
86
+ </div>
87
+
88
+ <div align="center" style="margin: 20px 0;">
89
+
90
+ **🌟 We encourage and appreciate community contributions that make HunyuanVideo-Foley more accessible!**
91
+
92
+ </div>
93
+
94
+ ---
95
+ ### ✨ **Key Highlights**
96
+
97
+ <table align="center" style="border: none; margin: 20px 0;">
98
+ <tr>
99
+ <td align="center" width="33%">
100
+
101
+ 🎭 **Multi-scenario Sync**
102
+ High-quality audio synchronized with complex video scenes
103
+
104
+ </td>
105
+ <td align="center" width="33%">
106
+
107
+ 🧠 **Multi-modal Balance**
108
+ Perfect harmony between visual and textual information
109
+
110
+ </td>
111
+ <td align="center" width="33%">
112
+
113
+ 🎵 **48kHz Hi-Fi Output**
114
+ Professional-grade audio generation with crystal clarity
115
+
116
+ </td>
117
+ </tr>
118
+ </table>
119
+
120
+ </div>
121
+
122
+ ---
123
+
124
+ ## 📄 **Abstract**
125
+
126
+ <div align="center" style="background: linear-gradient(135deg, #ffeef8 0%, #f0f8ff 100%); padding: 30px; border-radius: 20px; margin: 20px 0; border-left: 5px solid #ff6b9d; color: #333;">
127
+
128
+ **🚀 Tencent Hunyuan** open-sources **HunyuanVideo-Foley**, an end-to-end video sound effect generation model!
129
+
130
+ *A professional-grade AI tool specifically designed for video content creators, widely applicable to diverse scenarios including short video creation, film production, advertising creativity, and game development.*
131
+
132
+ </div>
133
+
134
+ ### 🎯 **Core Highlights**
135
+
136
+ <div style="display: grid; grid-template-columns: 1fr; gap: 15px; margin: 20px 0;">
137
+
138
+ <div style="border-left: 4px solid #4CAF50; padding: 15px; background: #f8f9fa; border-radius: 8px; color: #333;">
139
+
140
+ **🎬 Multi-scenario Audio-Visual Synchronization**
141
+ Supports generating high-quality audio that is synchronized and semantically aligned with complex video scenes, enhancing realism and immersive experience for film/TV and gaming applications.
142
+
143
+ </div>
144
+
145
+ <div style="border-left: 4px solid #2196F3; padding: 15px; background: #f8f9fa; border-radius: 8px; color: #333;">
146
+
147
+ **⚖️ Multi-modal Semantic Balance**
148
+ Intelligently balances visual and textual information analysis, comprehensively orchestrates sound effect elements, avoids one-sided generation, and meets personalized dubbing requirements.
149
+
150
+ </div>
151
+
152
+ <div style="border-left: 4px solid #FF9800; padding: 15px; background: #f8f9fa; border-radius: 8px; color: #333;">
153
+
154
+ **🎵 High-fidelity Audio Output**
155
+ Self-developed 48kHz audio VAE perfectly reconstructs sound effects, music, and vocals, achieving professional-grade audio generation quality.
156
+
157
+ </div>
158
+
159
+ </div>
160
+
161
+ <div align="center" style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px; border-radius: 15px; margin: 20px 0; color: #333;">
162
+
163
+ **🏆 SOTA Performance Achieved**
164
+
165
+ *HunyuanVideo-Foley comprehensively leads the field across multiple evaluation benchmarks, achieving new state-of-the-art levels in audio fidelity, visual-semantic alignment, temporal alignment, and distribution matching - surpassing all open-source solutions!*
166
+
167
+ </div>
168
+
169
+ <div align="center">
170
+
171
+ ![Performance Overview](assets/pan_chart.png)
172
+ *📊 Performance comparison across different evaluation metrics - HunyuanVideo-Foley leads in all categories*
173
+
174
+ </div>
175
+
176
+ ---
177
+
178
+ ## 🔧 **Technical Architecture**
179
+
180
+ ### 📊 **Data Pipeline Design**
181
+
182
+ <div align="center" style="margin: 20px 0; color: #333;">
183
+
184
+ ![Data Pipeline](assets/data_pipeline.png)
185
+ *🔄 Comprehensive data processing pipeline for high-quality text-video-audio datasets*
186
+
187
+ </div>
188
+
189
+ <div style="background: #f8f9fa; padding: 20px; border-radius: 10px; border-left: 4px solid #17a2b8; margin: 20px 0;">
190
+
191
+ The **TV2A (Text-Video-to-Audio)** task presents a complex multimodal generation challenge requiring large-scale, high-quality datasets. Our comprehensive data pipeline systematically identifies and excludes unsuitable content to produce robust and generalizable audio generation capabilities.
192
+
193
+ </div>
194
+
195
+ ### 🏗️ **Model Architecture**
196
+
197
+ <div align="center" style="margin: 20px 0; color: #333;">
198
+
199
+ ![Model Architecture](assets/model_arch.png)
200
+ *🧠 HunyuanVideo-Foley hybrid architecture with multimodal and unimodal transformer blocks*
201
+
202
+ </div>
203
+
204
+ <div style="background: #f8f9fa; padding: 20px; border-radius: 10px; border-left: 4px solid #28a745; margin: 20px 0;">
205
+
206
+ **HunyuanVideo-Foley** employs a sophisticated hybrid architecture:
207
+
208
+ - **🔄 Multimodal Transformer Blocks**: Process visual-audio streams simultaneously
209
+ - **🎵 Unimodal Transformer Blocks**: Focus on audio stream refinement
210
+ - **👁️ Visual Encoding**: Pre-trained encoder extracts visual features from video frames
211
+ - **📝 Text Processing**: Semantic features extracted via pre-trained text encoder
212
+ - **🎧 Audio Encoding**: Latent representations with Gaussian noise perturbation
213
+ - **⏰ Temporal Alignment**: Synchformer-based frame-level synchronization with gated modulation
214
+
215
+ </div>
216
+
217
+ ---
218
+
219
+ ## 📈 **Performance Benchmarks**
220
+
221
+ ### 🎬 **MovieGen-Audio-Bench Results**
222
+
223
+ <div align="center">
224
+
225
+ > *Objective and Subjective evaluation results demonstrating superior performance across all metrics*
226
+
227
+ </div>
228
+
229
+ <div style="overflow-x: auto; margin: 20px 0;">
230
+
231
+ | 🏆 **Method** | **PQ** ↑ | **PC** ↓ | **CE** ↑ | **CU** ↑ | **IB** ↑ | **DeSync** ↓ | **CLAP** ↑ | **MOS-Q** ↑ | **MOS-S** ↑ | **MOS-T** ↑ |
232
+ |:-------------:|:--------:|:--------:|:--------:|:--------:|:--------:|:-------------:|:-----------:|:------------:|:------------:|:------------:|
233
+ | FoleyGrafter | 6.27 | 2.72 | 3.34 | 5.68 | 0.17 | 1.29 | 0.14 | 3.36±0.78 | 3.54±0.88 | 3.46±0.95 |
234
+ | V-AURA | 5.82 | 4.30 | 3.63 | 5.11 | 0.23 | 1.38 | 0.14 | 2.55±0.97 | 2.60±1.20 | 2.70±1.37 |
235
+ | Frieren | 5.71 | 2.81 | 3.47 | 5.31 | 0.18 | 1.39 | 0.16 | 2.92±0.95 | 2.76±1.20 | 2.94±1.26 |
236
+ | MMAudio | 6.17 | 2.84 | 3.59 | 5.62 | 0.27 | 0.80 | 0.35 | 3.58±0.84 | 3.63±1.00 | 3.47±1.03 |
237
+ | ThinkSound | 6.04 | 3.73 | 3.81 | 5.59 | 0.18 | 0.91 | 0.20 | 3.20±0.97 | 3.01±1.04 | 3.02±1.08 |
238
+ | **HunyuanVideo-Foley (ours)** | **6.59** | **2.74** | **3.88** | **6.13** | **0.35** | **0.74** | **0.33** | **4.14±0.68** | **4.12±0.77** | **4.15±0.75** |
239
+
240
+ </div>
241
+
242
+
243
+ ### 🎯 **Kling-Audio-Eval Results**
244
+
245
+ <div align="center">
246
+
247
+ > *Comprehensive objective evaluation showcasing state-of-the-art performance*
248
+
249
+ </div>
250
+
251
+ <div style="overflow-x: auto; margin: 20px 0;">
252
+
253
+ | 🏆 **Method** | **FD_PANNs** ↓ | **FD_PASST** ↓ | **KL** ↓ | **IS** ↑ | **PQ** ↑ | **PC** ↓ | **CE** ↑ | **CU** ↑ | **IB** ↑ | **DeSync** ↓ | **CLAP** ↑ |
254
+ |:-------------:|:--------------:|:--------------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:-------------:|:-----------:|
255
+ | FoleyGrafter | 22.30 | 322.63 | 2.47 | 7.08 | 6.05 | 2.91 | 3.28 | 5.44 | 0.22 | 1.23 | 0.22 |
256
+ | V-AURA | 33.15 | 474.56 | 3.24 | 5.80 | 5.69 | 3.98 | 3.13 | 4.83 | 0.25 | 0.86 | 0.13 |
257
+ | Frieren | 16.86 | 293.57 | 2.95 | 7.32 | 5.72 | 2.55 | 2.88 | 5.10 | 0.21 | 0.86 | 0.16 |
258
+ | MMAudio | 9.01 | 205.85 | 2.17 | 9.59 | 5.94 | 2.91 | 3.30 | 5.39 | 0.30 | 0.56 | 0.27 |
259
+ | ThinkSound | 9.92 | 228.68 | 2.39 | 6.86 | 5.78 | 3.23 | 3.12 | 5.11 | 0.22 | 0.67 | 0.22 |
260
+ | **HunyuanVideo-Foley (ours)** | **6.07** | **202.12** | **1.89** | **8.30** | **6.12** | **2.76** | **3.22** | **5.53** | **0.38** | **0.54** | **0.24** |
261
+
262
+ </div>
263
+
264
+ <div align="center" style="background: linear-gradient(135deg, #4CAF50 0%, #45a049 100%); color: white; padding: 15px; border-radius: 10px; margin: 20px 0; color: #333;">
265
+
266
+ **🎉 Outstanding Results!** HunyuanVideo-Foley achieves the best scores across **ALL** evaluation metrics, demonstrating significant improvements in audio quality, synchronization, and semantic alignment.
267
+
268
+ </div>
269
+
270
+
271
+
272
+ ---
273
+
274
+ ## 🚀 **Quick Start**
275
+
276
+ ### 📦 **Installation**
277
+
278
+ <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px; border-radius: 15px; margin: 20px 0; color: #333;">
279
+
280
+ **🔧 System Requirements**
281
+ - **CUDA**: 12.4 or 11.8 recommended
282
+ - **Python**: 3.8+
283
+ - **OS**: Linux (primary support)
284
+ - **VRAM**: 20GB for XXL model (or 12GB with `--enable_offload`), 16GB for XL model (or 8GB with `--enable_offload`)
285
+
286
+ </div>
287
+
288
+ #### **Step 1: Clone Repository**
289
+
290
+ ```bash
291
+ # 📥 Clone the repository
292
+ git clone https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley
293
+ cd HunyuanVideo-Foley
294
+ ```
295
+
296
+ #### **Step 2: Environment Setup**
297
+
298
+ <div style="background: #fff3cd; padding: 15px; border-radius: 8px; border-left: 4px solid #ffc107; margin: 10px 0; color: #333;">
299
+
300
+ 💡 **Tip**: We recommend using [Conda](https://docs.anaconda.com/free/miniconda/index.html) for Python environment management.
301
+
302
+ </div>
303
+
304
+ ```bash
305
+ # 🔧 Install dependencies
306
+ pip install -r requirements.txt
307
+ ```
308
+
309
+ #### **Step 3: Download Pretrained Models**
310
+
311
+ <div style="background: #d1ecf1; padding: 15px; border-radius: 8px; border-left: 4px solid #17a2b8; margin: 10px 0;color: #333;">
312
+
313
+ 🔗 **Download Model weights from Huggingface**
314
+ ```bash
315
+ # using git-lfs
316
+ git clone https://huggingface.co/tencent/HunyuanVideo-Foley
317
+
318
+ # using huggingface-cli
319
+ huggingface-cli download tencent/HunyuanVideo-Foley
320
+ ```
321
+
322
+ <!-- 🔗 **Download Model weights from ModelScope** -->
323
+ <!-- ```bash -->
324
+ <!-- # using git-lfs -->
325
+ <!-- git clone https://huggingface.co/tencent/HunyuanVideo-Foley -->
326
+ <!-- -->
327
+ <!-- # using huggingface-cli -->
328
+ <!-- huggingface-cli download tencent/HunyuanVideo-Foley -->
329
+ <!-- ``` -->
330
+
331
+ </div>
332
+
333
+
334
+ ---
335
+
336
+ ## 💻 **Usage**
337
+
338
+ ### 📊 **Model Specifications**
339
+
340
+ | Model | Checkpoint | VRAM (Normal) | VRAM (Offload) |
341
+ |-------|------------|---------------|----------------|
342
+ | **XXL** *(Default)* | `hunyuanvideo_foley.pth` | 20GB | 12GB |
343
+ | **XL** | `hunyuanvideo_foley_xl.pth` | 16GB | 8GB |
344
+
345
+ ### 🎬 **Single Video Generation**
346
+
347
+ <div style="background: #e8f5e8; padding: 15px; border-radius: 8px; border-left: 4px solid #28a745; margin: 10px 0;color: #333;">
348
+
349
+ Generate Foley audio for a single video file with text description:
350
+
351
+ </div>
352
+
353
+ ```bash
354
+ # Use XXL model (default, best quality)
355
+ python3 infer.py \
356
+ --model_path PRETRAINED_MODEL_PATH_DIR \
357
+ --single_video video_path \
358
+ --single_prompt "audio description" \
359
+ --output_dir OUTPUT_DIR \
360
+ # --enable_offload
361
+
362
+ # Use XL model (memory-friendly)
363
+ python3 infer.py \
364
+ --model_path PRETRAINED_MODEL_PATH_DIR \
365
+ --model_size xl \
366
+ --single_video video_path \
367
+ --single_prompt "audio description" \
368
+ --output_dir OUTPUT_DIR \
369
+ # --enable_offload
370
+ ```
371
+
372
+ ### 📂 **Batch Processing**
373
+
374
+ <div style="background: #fff3e0; padding: 15px; border-radius: 8px; border-left: 4px solid #ff9800; margin: 10px 0;color: #333;">
375
+
376
+ Process multiple videos using a CSV file with video paths and descriptions:
377
+
378
+ </div>
379
+
380
+ ```bash
381
+ # Download sample test videos
382
+ bash ./download_test_videos.sh
383
+
384
+ # Batch processing
385
+ python3 infer.py \
386
+ --model_path PRETRAINED_MODEL_PATH_DIR \
387
+ --csv_path assets/test.csv \
388
+ --output_dir OUTPUT_DIR \
389
+ # --enable_offload
390
+ ```
391
+
392
+ ### 🌐 **Interactive Web Interface**
393
+
394
+ <div style="background: #f3e5f5; padding: 15px; border-radius: 8px; border-left: 4px solid #9c27b0; margin: 10px 0;color: #333;">
395
+
396
+ Launch a user-friendly Gradio web interface for easy interaction:
397
+
398
+ </div>
399
+
400
+ ```bash
401
+ # Launch with XXL model (default)
402
+ export HIFI_FOLEY_MODEL_PATH=PRETRAINED_MODEL_PATH_DIR
403
+ python3 gradio_app.py
404
+
405
+ # Launch with XL model (memory-friendly)
406
+ export HIFI_FOLEY_MODEL_PATH=PRETRAINED_MODEL_PATH_DIR
407
+ MODEL_SIZE=xl python3 gradio_app.py
408
+
409
+ # Optional: Enable offload to reduce memory usage
410
+ ENABLE_OFFLOAD=true python3 gradio_app.py
411
+ ```
412
+
413
+ <div align="center" style="margin: 20px 0; color: #333;">
414
+
415
+ *🚀 Then open your browser and navigate to the provided local URL to start generating Foley audio!*
416
+
417
+ </div>
418
+
419
+ ---
420
+
421
+ ## 📚 **Citation**
422
+
423
+ <div style="background: #f8f9fa; padding: 20px; border-radius: 10px; border-left: 4px solid #6c757d; margin: 20px 0; color: #333;">
424
+
425
+ If you find **HunyuanVideo-Foley** useful for your research, please consider citing our paper:
426
+
427
+ </div>
428
+
429
+ ```bibtex
430
+ @misc{shan2025hunyuanvideofoleymultimodaldiffusionrepresentation,
431
+ title={HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment for High-Fidelity Foley Audio Generation},
432
+ author={Sizhe Shan and Qiulin Li and Yutao Cui and Miles Yang and Yuehai Wang and Qun Yang and Jin Zhou and Zhao Zhong},
433
+ year={2025},
434
+ eprint={2508.16930},
435
+ archivePrefix={arXiv},
436
+ primaryClass={eess.AS},
437
+ url={https://arxiv.org/abs/2508.16930},
438
+ }
439
+ ```
440
+ ## Star History
441
+
442
+ [![Star History Chart](https://api.star-history.com/svg?repos=Tencent-Hunyuan/HunyuanVideo-Foley&type=Date)](https://www.star-history.com/#Tencent-Hunyuan/HunyuanVideo-Foley&Date)
443
+ ---
444
+
445
+ ## 🙏 **Acknowledgements**
446
+
447
+ <div align="center">
448
+
449
+ **We extend our heartfelt gratitude to the open-source community!**
450
+
451
+ </div>
452
+
453
+ <table align="center" style="width: 100%; border: none; margin: 20px 0;">
454
+ <tr>
455
+ <td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
456
+
457
+ 🎨 **[Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)**
458
+ *Foundation diffusion models*
459
+
460
+ </td>
461
+ <td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
462
+
463
+ ⚡ **[FLUX](https://github.com/black-forest-labs/flux)**
464
+ *Advanced generation techniques*
465
+
466
+ </td>
467
+ <td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
468
+
469
+ 🎵 **[MMAudio](https://github.com/hkchengrex/MMAudio)**
470
+ *Multimodal audio generation*
471
+
472
+ </td>
473
+ </tr>
474
+ <tr>
475
+ <td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
476
+
477
+ 🤗 **[HuggingFace](https://huggingface.co)**
478
+ *Platform & diffusers library*
479
+
480
+ </td>
481
+ <td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
482
+
483
+ 🗜️ **[DAC](https://github.com/descriptinc/descript-audio-codec)**
484
+ *High-Fidelity Audio Compression*
485
+
486
+ </td>
487
+ <td align="center" style="width: 33%; padding: 10px; vertical-align: top;">
488
+
489
+ 🔗 **[Synchformer](https://github.com/v-iashin/Synchformer)**
490
+ *Audio-Visual Synchronization*
491
+
492
+ </td>
493
+ </tr>
494
+ </table>
495
+
496
+ <div align="center" style="background: linear-gradient(135deg, #74b9ff 0%, #0984e3 100%); color: white; padding: 20px; border-radius: 15px; margin: 20px 0;, color: #333;">
497
+
498
+ **🌟 Special thanks to all researchers and developers who contribute to the advancement of AI-generated audio and multimodal learning!**
499
+
500
+ </div>
501
+
502
+
503
+ ---
504
+
505
+ <div align="center" style="margin: 30px 0;">
506
+
507
+ ### 🔗 **Connect with Us**
508
+
509
+ [![GitHub](https://img.shields.io/badge/GitHub-Follow-black?style=for-the-badge&logo=github)](https://github.com/Tencent-Hunyuan)
510
+ [![Twitter](https://img.shields.io/badge/Twitter-Follow-blue?style=for-the-badge&logo=twitter)](https://twitter.com/Tencent)
511
+ [![Hunyuan](https://img.shields.io/badge/Website-HunyuanAI-green?style=for-the-badge&logo=hunyuan)](https://hunyuan.tencent.com/)
512
+
513
+ <p style="color: #666; margin-top: 15px; font-size: 14px;">
514
+
515
+ © 2025 Tencent Hunyuan. All rights reserved. | Made with ❤️ for the AI community
516
+
517
+ </p>
518
+
519
+ </div>
HunyuanVideo-Foley/build_package.sh ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # 构建 HunyuanVideo-Foley Python 包的脚本
3
+
4
+ set -e # 出现错误时退出
5
+
6
+ echo "🚀 开始构建 HunyuanVideo-Foley Python 包..."
7
+
8
+ # 清理之前的构建文件
9
+ echo "🧹 清理之前的构建文件..."
10
+ rm -rf build/ dist/ *.egg-info/
11
+
12
+ # 检查必要的工具
13
+ echo "🔍 检查构建工具..."
14
+ python -c "import setuptools, wheel; print('✅ setuptools和wheel已安装')" || {
15
+ echo "❌ 请安装构建工具: pip install setuptools wheel"
16
+ exit 1
17
+ }
18
+
19
+ # 检查setup.py
20
+ echo "🔍 验证setup.py配置..."
21
+ python setup.py check --restructuredtext --strict || {
22
+ echo "⚠️ setup.py验证有警告,但继续构建..."
23
+ }
24
+
25
+ # 构建源码分发包
26
+ echo "📦 构建源码分发包..."
27
+ python setup.py sdist
28
+
29
+ # 构建wheel包
30
+ echo "🎡 构建wheel包..."
31
+ python setup.py bdist_wheel
32
+
33
+ # 显示构建结果
34
+ echo "✅ 构建完成!生成的包:"
35
+ ls -la dist/
36
+
37
+ # 验证包
38
+ echo "🔍 验证生成的包..."
39
+ python -m pip check dist/*.whl || echo "⚠️ 包验证有警告"
40
+
41
+ echo ""
42
+ echo "📝 安装说明:"
43
+ echo "# 从wheel文件安装:"
44
+ echo "pip install dist/hunyuanvideo_foley-1.0.0-py3-none-any.whl"
45
+ echo ""
46
+ echo "# 开发模式安装:"
47
+ echo "pip install -e ."
48
+ echo ""
49
+ echo "# 安装所有可选依赖:"
50
+ echo "pip install -e .[all]"
51
+ echo ""
52
+
53
+ echo "⚠️ 注意:某些依赖需要单独安装:"
54
+ echo "pip install git+https://github.com/descriptinc/audiotools"
55
+ echo "pip install git+https://github.com/huggingface/transformers@v4.49.0-SigLIP-2"
56
+
57
+ echo ""
58
+ echo "🎉 构建完成!查看 INSTALL.md 获取详细安装指南。"
HunyuanVideo-Foley/configs/hunyuanvideo-foley-xl.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ model_name: HunyuanVideo-Foley-XL
3
+ model_type: 1d
4
+ model_precision: bf16
5
+ model_kwargs:
6
+ depth_triple_blocks: 12
7
+ depth_single_blocks: 24
8
+ hidden_size: 1408
9
+ num_heads: 11
10
+ mlp_ratio: 4
11
+ mlp_act_type: "gelu_tanh"
12
+ qkv_bias: True
13
+ qk_norm: True
14
+ qk_norm_type: "rms"
15
+ attn_mode: "torch"
16
+ embedder_type: "default"
17
+ interleaved_audio_visual_rope: True
18
+ enable_learnable_empty_visual_feat: True
19
+ sync_modulation: False
20
+ add_sync_feat_to_audio: True
21
+ cross_attention: True
22
+ use_attention_mask: False
23
+ condition_projection: "linear"
24
+ sync_feat_dim: 768 # syncformer 768 dim
25
+ condition_dim: 768 # clap 768 text condition dim (clip-text)
26
+ clip_dim: 768 # siglip2 visual dim
27
+ audio_vae_latent_dim: 128
28
+ audio_frame_rate: 50
29
+ patch_size: 1
30
+ rope_dim_list: null
31
+ rope_theta: 10000
32
+ text_length: 77
33
+ clip_length: 64
34
+ sync_length: 192
35
+ depth_triple_ssl_encoder: null
36
+ depth_single_ssl_encoder: 8
37
+ use_repa_with_audiossl: True
38
+
39
+ diffusion_config:
40
+ denoise_type: "flow"
41
+ flow_path_type: "linear"
42
+ flow_predict_type: "velocity"
43
+ flow_reverse: True
44
+ flow_solver: "euler"
45
+ sample_flow_shift: 1.0
46
+ sample_use_flux_shift: False
47
+ flux_base_shift: 0.5
48
+ flux_max_shift: 1.15
HunyuanVideo-Foley/configs/hunyuanvideo-foley-xxl.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ model_name: HunyuanVideo-Foley-XXL
3
+ model_type: 1d
4
+ model_precision: bf16
5
+ model_kwargs:
6
+ depth_triple_blocks: 18
7
+ depth_single_blocks: 36
8
+ hidden_size: 1536
9
+ num_heads: 12
10
+ mlp_ratio: 4
11
+ mlp_act_type: "gelu_tanh"
12
+ qkv_bias: True
13
+ qk_norm: True
14
+ qk_norm_type: "rms"
15
+ attn_mode: "torch"
16
+ embedder_type: "default"
17
+ interleaved_audio_visual_rope: True
18
+ enable_learnable_empty_visual_feat: True
19
+ sync_modulation: False
20
+ add_sync_feat_to_audio: True
21
+ cross_attention: True
22
+ use_attention_mask: False
23
+ condition_projection: "linear"
24
+ sync_feat_dim: 768 # syncformer 768 dim
25
+ condition_dim: 768 # clap 768 text condition dim (clip-text)
26
+ clip_dim: 768 # siglip2 visual dim
27
+ audio_vae_latent_dim: 128
28
+ audio_frame_rate: 50
29
+ patch_size: 1
30
+ rope_dim_list: null
31
+ rope_theta: 10000
32
+ text_length: 77
33
+ clip_length: 64
34
+ sync_length: 192
35
+ depth_triple_ssl_encoder: null
36
+ depth_single_ssl_encoder: 8
37
+ use_repa_with_audiossl: True
38
+
39
+ diffusion_config:
40
+ denoise_type: "flow"
41
+ flow_path_type: "linear"
42
+ flow_predict_type: "velocity"
43
+ flow_reverse: True
44
+ flow_solver: "euler"
45
+ sample_flow_shift: 1.0
46
+ sample_use_flux_shift: False
47
+ flux_base_shift: 0.5
48
+ flux_max_shift: 1.15
HunyuanVideo-Foley/download_test_videos.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Download MoviegenAudioBenchSfx 10 videos
4
+ curl -O https://texttoaudio-train-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuanvideo-foley_demo/MovieGenAudioBenchSfx.tar.gz
5
+ tar -xzvf MovieGenAudioBenchSfx.tar.gz -C ./assets
6
+ rm MovieGenAudioBenchSfx.tar.gz
7
+
8
+ # Download gradio example video
9
+ curl -O https://texttoaudio-train-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuanvideo-foley_demo/examples.tar.gz
10
+ tar -xvzf examples.tar.gz
11
+ rm examples.tar.gz
HunyuanVideo-Foley/gradio_app.py ADDED
@@ -0,0 +1,834 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import gradio as gr
4
+ import torch
5
+ import torchaudio
6
+ from loguru import logger
7
+ from typing import Optional, Tuple
8
+ import random
9
+ import numpy as np
10
+
11
+ from hunyuanvideo_foley.utils.model_utils import load_model
12
+ from hunyuanvideo_foley.utils.feature_utils import feature_process
13
+ from hunyuanvideo_foley.utils.model_utils import denoise_process
14
+ from hunyuanvideo_foley.utils.media_utils import merge_audio_video
15
+
16
+ # Global variables for model storage
17
+ model_dict = None
18
+ cfg = None
19
+ device = None
20
+
21
+ # need to modify the model path
22
+ MODEL_PATH = os.environ.get("HIFI_FOLEY_MODEL_PATH", "./pretrained_models/")
23
+ ENABLE_OFFLOAD = os.environ.get("ENABLE_OFFLOAD", "false").lower() in ("true", "1", "yes")
24
+ MODEL_SIZE = os.environ.get("MODEL_SIZE", "xxl") # default to xxl model
25
+ CONFIG_PATH = os.environ.get("CONFIG_PATH", "")
26
+
27
+ def setup_device(device_str: str = "auto", gpu_id: int = 0) -> torch.device:
28
+ """Setup computing device"""
29
+ if device_str == "auto":
30
+ if torch.cuda.is_available():
31
+ device = torch.device(f"cuda:{gpu_id}")
32
+ logger.info(f"Using CUDA device: {device}")
33
+ elif torch.backends.mps.is_available():
34
+ device = torch.device("mps")
35
+ logger.info("Using MPS device")
36
+ else:
37
+ device = torch.device("cpu")
38
+ logger.info("Using CPU device")
39
+ else:
40
+ if device_str == "cuda":
41
+ device = torch.device(f"cuda:{gpu_id}")
42
+ else:
43
+ device = torch.device(device_str)
44
+ logger.info(f"Using specified device: {device}")
45
+
46
+ return device
47
+
48
+ def auto_load_models() -> str:
49
+ """Automatically load preset models"""
50
+ global model_dict, cfg, device
51
+
52
+ try:
53
+ if not os.path.exists(MODEL_PATH):
54
+ return f"❌ Model directory not found: {MODEL_PATH}"
55
+
56
+ # Use GPU by default
57
+ device = setup_device("auto", 0)
58
+
59
+ # Auto-select config if not specified
60
+ config_path = CONFIG_PATH
61
+ if not config_path:
62
+ config_mapping = {
63
+ "xl": "configs/hunyuanvideo-foley-xl.yaml",
64
+ "xxl": "configs/hunyuanvideo-foley-xxl.yaml"
65
+ }
66
+ config_path = config_mapping.get(MODEL_SIZE, "configs/hunyuanvideo-foley-xxl.yaml")
67
+
68
+ # Load model
69
+ logger.info("Auto-loading model...")
70
+ logger.info(f"Model path: {MODEL_PATH}")
71
+ logger.info(f"Model size: {MODEL_SIZE}")
72
+ logger.info(f"Config path: {config_path}")
73
+ logger.info(f"Offload mode: {'enabled' if ENABLE_OFFLOAD else 'disabled'}")
74
+
75
+ model_dict, cfg = load_model(MODEL_PATH, config_path, device, enable_offload=ENABLE_OFFLOAD, model_size=MODEL_SIZE)
76
+
77
+ logger.info("✅ Model loaded successfully!")
78
+ return "✅ Model loaded successfully!"
79
+
80
+ except Exception as e:
81
+ logger.error(f"Model loading failed: {str(e)}")
82
+ return f"❌ Model loading failed: {str(e)}"
83
+
84
+ def infer_single_video(
85
+ video_file,
86
+ text_prompt: str,
87
+ neg_prompt: str = None,
88
+ guidance_scale: float = 4.5,
89
+ num_inference_steps: int = 50,
90
+ sample_nums: int = 1
91
+ ) -> Tuple[list, str]:
92
+ """Single video inference"""
93
+ global model_dict, cfg, device
94
+
95
+ if model_dict is None or cfg is None:
96
+ return [], "❌ Please load the model first!"
97
+
98
+ if video_file is None:
99
+ return [], "❌ Please upload a video file!"
100
+
101
+ # Allow empty text prompt, use empty string if no prompt provided
102
+ if text_prompt is None:
103
+ text_prompt = ""
104
+ text_prompt = text_prompt.strip()
105
+
106
+ try:
107
+ logger.info(f"Processing video: {video_file}")
108
+ logger.info(f"Text prompt: {text_prompt}")
109
+
110
+ # Feature processing
111
+ visual_feats, text_feats, audio_len_in_s = feature_process(
112
+ video_file,
113
+ text_prompt,
114
+ model_dict,
115
+ cfg,
116
+ neg_prompt=neg_prompt
117
+ )
118
+
119
+ # Denoising process to generate multiple audio samples
120
+ # Note: The model now generates sample_nums audio samples per inference
121
+ # The denoise_process function returns audio with shape [batch_size, channels, samples]
122
+ logger.info(f"Generating {sample_nums} audio samples...")
123
+ audio, sample_rate = denoise_process(
124
+ visual_feats,
125
+ text_feats,
126
+ audio_len_in_s,
127
+ model_dict,
128
+ cfg,
129
+ guidance_scale=guidance_scale,
130
+ num_inference_steps=num_inference_steps,
131
+ batch_size=sample_nums
132
+ )
133
+
134
+ # Create temporary files to save results
135
+ temp_dir = tempfile.mkdtemp()
136
+ video_outputs = []
137
+
138
+ # Process each generated audio sample
139
+ for i in range(sample_nums):
140
+ # Save audio file
141
+ audio_output = os.path.join(temp_dir, f"generated_audio_{i+1}.wav")
142
+ torchaudio.save(audio_output, audio[i], sample_rate)
143
+
144
+ # Merge video and audio
145
+ video_output = os.path.join(temp_dir, f"video_with_audio_{i+1}.mp4")
146
+ merge_audio_video(audio_output, video_file, video_output)
147
+ video_outputs.append(video_output)
148
+
149
+ logger.info(f"Inference completed! Generated {sample_nums} samples.")
150
+ return video_outputs, f"✅ Generated {sample_nums} audio sample(s) successfully!"
151
+
152
+ except Exception as e:
153
+ logger.error(f"Inference failed: {str(e)}")
154
+ return [], f"❌ Inference failed: {str(e)}"
155
+
156
+ def update_video_outputs(video_list, status_msg):
157
+ """Update video outputs based on the number of generated samples"""
158
+ # Initialize all outputs as None
159
+ outputs = [None] * 6
160
+
161
+ # Set values based on generated videos
162
+ for i, video_path in enumerate(video_list[:6]): # Max 6 samples
163
+ outputs[i] = video_path
164
+
165
+ # Return all outputs plus status message
166
+ return tuple(outputs + [status_msg])
167
+
168
+ def create_gradio_interface():
169
+ """Create Gradio interface"""
170
+
171
+ # Custom CSS for beautiful interface with better contrast
172
+ css = """
173
+ .gradio-container {
174
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
175
+ background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
176
+ min-height: 100vh;
177
+ }
178
+
179
+ .main-header {
180
+ text-align: center;
181
+ padding: 2rem 0;
182
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
183
+ border-radius: 20px;
184
+ margin-bottom: 2rem;
185
+ box-shadow: 0 8px 32px rgba(0,0,0,0.15);
186
+ }
187
+
188
+ .main-header h1 {
189
+ color: white;
190
+ font-size: 3rem;
191
+ font-weight: 700;
192
+ margin-bottom: 0.5rem;
193
+ text-shadow: 0 2px 10px rgba(0,0,0,0.3);
194
+ }
195
+
196
+ .main-header p {
197
+ color: rgba(255, 255, 255, 0.95);
198
+ font-size: 1.2rem;
199
+ font-weight: 300;
200
+ }
201
+
202
+ .status-card {
203
+ background: white;
204
+ border-radius: 15px;
205
+ padding: 1rem;
206
+ margin-bottom: 1.5rem;
207
+ border: 1px solid #e1e5e9;
208
+ box-shadow: 0 4px 20px rgba(0,0,0,0.08);
209
+ }
210
+
211
+ .status-card label {
212
+ color: #2d3748 !important;
213
+ font-weight: 600 !important;
214
+ }
215
+
216
+ .usage-guide h3 {
217
+ color: #2d3748 !important;
218
+ font-weight: 600 !important;
219
+ margin-bottom: 0.5rem !important;
220
+ }
221
+
222
+ .usage-guide p {
223
+ color: #4a5568 !important;
224
+ font-size: 1rem !important;
225
+ line-height: 1.6 !important;
226
+ margin: 0.5rem 0 !important;
227
+ }
228
+
229
+ .usage-guide strong {
230
+ color: #1a202c !important;
231
+ font-weight: 700 !important;
232
+ }
233
+
234
+ .usage-guide em {
235
+ color: #1a202c !important;
236
+ font-weight: 700 !important;
237
+ font-style: normal !important;
238
+ }
239
+
240
+ .main-interface {
241
+ margin-bottom: 2rem;
242
+ }
243
+
244
+ .input-section {
245
+ background: white;
246
+ border-radius: 20px;
247
+ padding: 2rem;
248
+ margin-right: 1rem;
249
+ box-shadow: 0 8px 32px rgba(0,0,0,0.1);
250
+ border: 1px solid #e1e5e9;
251
+ }
252
+
253
+ .input-section h3 {
254
+ color: #2d3748 !important;
255
+ font-weight: 600 !important;
256
+ margin-bottom: 1rem !important;
257
+ }
258
+
259
+ .input-section label {
260
+ color: #4a5568 !important;
261
+ font-weight: 500 !important;
262
+ }
263
+
264
+ .output-section {
265
+ background: white;
266
+ border-radius: 20px;
267
+ padding: 2rem;
268
+ margin-left: 1rem;
269
+ box-shadow: 0 8px 32px rgba(0,0,0,0.1);
270
+ border: 1px solid #e1e5e9;
271
+ }
272
+
273
+ .output-section h3 {
274
+ color: #2d3748 !important;
275
+ font-weight: 600 !important;
276
+ margin-bottom: 1rem !important;
277
+ }
278
+
279
+ .output-section label {
280
+ color: #4a5568 !important;
281
+ font-weight: 500 !important;
282
+ }
283
+
284
+ .examples-section h3 {
285
+ color: #2d3748 !important;
286
+ font-weight: 600 !important;
287
+ margin-bottom: 1.5rem !important;
288
+ }
289
+
290
+ .generate-btn {
291
+ background: linear-gradient(45deg, #667eea, #764ba2) !important;
292
+ border: none !important;
293
+ color: white !important;
294
+ font-weight: 600 !important;
295
+ font-size: 1.1rem !important;
296
+ padding: 12px 30px !important;
297
+ border-radius: 25px !important;
298
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
299
+ transition: all 0.3s ease !important;
300
+ }
301
+
302
+ .generate-btn:hover {
303
+ transform: translateY(-2px) !important;
304
+ box-shadow: 0 8px 25px rgba(102, 126, 234, 0.6) !important;
305
+ }
306
+
307
+
308
+
309
+ .examples-section {
310
+ background: white;
311
+ border-radius: 20px;
312
+ padding: 2rem;
313
+ margin-top: 2rem;
314
+ box-shadow: 0 8px 32px rgba(0,0,0,0.1);
315
+ border: 1px solid #e1e5e9;
316
+ }
317
+
318
+ .examples-section p {
319
+ color: #4a5568 !important;
320
+ margin-bottom: 1rem !important;
321
+ }
322
+
323
+ .example-row {
324
+ background: #f8fafc;
325
+ border: 1px solid #e2e8f0;
326
+ border-radius: 15px;
327
+ padding: 1.5rem;
328
+ margin: 1rem 0;
329
+ transition: all 0.3s ease;
330
+ align-items: center;
331
+ }
332
+
333
+ .example-row:hover {
334
+ border-color: #667eea;
335
+ transform: translateY(-2px);
336
+ box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15);
337
+ }
338
+
339
+ .example-row .markdown {
340
+ color: #2d3748 !important;
341
+ }
342
+
343
+ .example-row .markdown p {
344
+ color: #2d3748 !important;
345
+ margin: 0.5rem 0 !important;
346
+ line-height: 1.5 !important;
347
+ }
348
+
349
+ .example-row .markdown strong {
350
+ color: #1a202c !important;
351
+ font-weight: 600 !important;
352
+ }
353
+
354
+ /* Example grid layout styles */
355
+ .example-grid-row {
356
+ margin: 1rem 0;
357
+ gap: 1rem;
358
+ }
359
+
360
+ .example-item {
361
+ background: #f8fafc;
362
+ border: 1px solid #e2e8f0;
363
+ border-radius: 15px;
364
+ padding: 1rem;
365
+ transition: all 0.3s ease;
366
+ margin: 0.25rem;
367
+ max-width: 250px;
368
+ margin-left: auto;
369
+ margin-right: auto;
370
+ }
371
+
372
+ .example-item:hover {
373
+ border-color: #667eea;
374
+ transform: translateY(-2px);
375
+ box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15);
376
+ }
377
+
378
+ .example-caption {
379
+ margin: 0.5rem 0 !important;
380
+ min-height: 2.8rem !important;
381
+ display: flex !important;
382
+ align-items: flex-start !important;
383
+ }
384
+
385
+ .example-caption p {
386
+ color: #2d3748 !important;
387
+ font-size: 0.9rem !important;
388
+ line-height: 1.4 !important;
389
+ margin: 0.5rem 0 !important;
390
+ }
391
+
392
+ /* Multi-video gallery styles */
393
+ .additional-samples {
394
+ margin-top: 1rem;
395
+ gap: 0.5rem;
396
+ }
397
+
398
+ .additional-samples .gradio-video {
399
+ border-radius: 10px;
400
+ overflow: hidden;
401
+ }
402
+
403
+ /* Video gallery responsive layout */
404
+ .video-gallery {
405
+ display: grid;
406
+ gap: 1rem;
407
+ margin-top: 1rem;
408
+ }
409
+
410
+ .video-gallery.single {
411
+ grid-template-columns: 1fr;
412
+ }
413
+
414
+ .video-gallery.dual {
415
+ grid-template-columns: 1fr 1fr;
416
+ }
417
+
418
+ .video-gallery.multi {
419
+ grid-template-columns: repeat(2, 1fr);
420
+ grid-template-rows: auto auto auto;
421
+ }
422
+
423
+ .footer-text {
424
+ color: #718096 !important;
425
+ text-align: center;
426
+ padding: 2rem;
427
+ font-size: 0.9rem;
428
+ }
429
+
430
+ /* Video component styling for consistent size */
431
+ .input-section video,
432
+ .output-section video,
433
+ .example-row video {
434
+ width: 100% !important;
435
+ height: 300px !important;
436
+ object-fit: contain !important;
437
+ border-radius: 10px !important;
438
+ background-color: #000 !important;
439
+ }
440
+
441
+ .example-row video {
442
+ height: 150px !important;
443
+ }
444
+
445
+ /* Fix for additional samples video display */
446
+ .additional-samples video {
447
+ height: 150px !important;
448
+ object-fit: contain !important;
449
+ border-radius: 10px !important;
450
+ background-color: #000 !important;
451
+ }
452
+
453
+ .additional-samples .gradio-video {
454
+ border-radius: 10px !important;
455
+ overflow: hidden !important;
456
+ background-color: #000 !important;
457
+ }
458
+
459
+ .additional-samples .gradio-video > div {
460
+ background-color: #000 !important;
461
+ border-radius: 10px !important;
462
+ }
463
+
464
+ /* Video container styling */
465
+ .input-section .video-container,
466
+ .output-section .video-container,
467
+ .example-row .video-container {
468
+ background-color: #000 !important;
469
+ border-radius: 10px !important;
470
+ display: flex !important;
471
+ align-items: center !important;
472
+ justify-content: center !important;
473
+ overflow: hidden !important;
474
+ }
475
+
476
+ /* Ensure proper alignment */
477
+ .example-row {
478
+ display: flex !important;
479
+ align-items: stretch !important;
480
+ }
481
+
482
+ .example-row > div {
483
+ display: flex !important;
484
+ flex-direction: column !important;
485
+ justify-content: center !important;
486
+ }
487
+
488
+ /* Video wrapper for better control */
489
+ .video-wrapper {
490
+ position: relative !important;
491
+ width: 100% !important;
492
+ background: #000 !important;
493
+ border-radius: 10px !important;
494
+ overflow: hidden !important;
495
+ display: flex !important;
496
+ align-items: center !important;
497
+ justify-content: center !important;
498
+ }
499
+ """
500
+
501
+ with gr.Blocks(css=css, title="HunyuanVideo-Foley") as app:
502
+
503
+ # Main header
504
+ with gr.Column(elem_classes=["main-header"]):
505
+ gr.HTML("""
506
+ <h1>🎵 HunyuanVideo-Foley</h1>
507
+ <p>Text-Video-to-Audio Synthesis: Generate realistic audio from video and text descriptions</p>
508
+ """)
509
+
510
+ # Usage Guide
511
+ with gr.Column(elem_classes=["status-card"]):
512
+ gr.Markdown("""
513
+ ### 📋 Quick Start Guide
514
+ **1.** Upload your video file\t**2.** Add optional text description\t**3.** Adjust sample numbers (1-6)\t**4.** Click Generate Audio
515
+
516
+ 💡 For quick start, you can load the prepared examples by clicking the button.
517
+ """, elem_classes=["usage-guide"])
518
+
519
+ # Main inference interface - Input and Results side by side
520
+ with gr.Row(elem_classes=["main-interface"]):
521
+ # Input section
522
+ with gr.Column(scale=1, elem_classes=["input-section"]):
523
+ gr.Markdown("### 📹 Video Input")
524
+
525
+ video_input = gr.Video(
526
+ label="Upload Video",
527
+ info="Supported formats: MP4, AVI, MOV, etc.",
528
+ height=300
529
+ )
530
+
531
+ text_input = gr.Textbox(
532
+ label="🎯 Audio Description (English)",
533
+ placeholder="A person walks on frozen ice",
534
+ lines=3,
535
+ info="Describe the audio you want to generate (optional)"
536
+ )
537
+
538
+ neg_prompt_input = gr.Textbox(
539
+ label="🚫 Negative Prompt",
540
+ placeholder="noisy, harsh",
541
+ lines=2,
542
+ info="Describe what you want to avoid in the generated audio (optional, default: 'noisy, harsh')"
543
+ )
544
+
545
+ with gr.Row():
546
+ guidance_scale = gr.Slider(
547
+ minimum=1.0,
548
+ maximum=10.0,
549
+ value=4.5,
550
+ step=0.1,
551
+ label="🎚️ CFG Scale",
552
+ )
553
+
554
+ inference_steps = gr.Slider(
555
+ minimum=10,
556
+ maximum=100,
557
+ value=50,
558
+ step=5,
559
+ label="⚡ Steps",
560
+ )
561
+
562
+ sample_nums = gr.Slider(
563
+ minimum=1,
564
+ maximum=6,
565
+ value=1,
566
+ step=1,
567
+ label="🎲 Sample Nums",
568
+ )
569
+
570
+ generate_btn = gr.Button(
571
+ "🎵 Generate Audio",
572
+ variant="primary",
573
+ elem_classes=["generate-btn"]
574
+ )
575
+
576
+ # Results section
577
+ with gr.Column(scale=1, elem_classes=["output-section"]):
578
+ gr.Markdown("### 🎥 Generated Results")
579
+
580
+ # Multi-video gallery for displaying multiple generated samples
581
+ with gr.Column():
582
+ # Primary video (Sample 1)
583
+ video_output_1 = gr.Video(
584
+ label="Sample 1",
585
+ height=250,
586
+ visible=True
587
+ )
588
+
589
+ # Additional videos (Samples 2-6) - initially hidden
590
+ with gr.Row(elem_classes=["additional-samples"]):
591
+ with gr.Column(scale=1):
592
+ video_output_2 = gr.Video(
593
+ label="Sample 2",
594
+ height=150,
595
+ visible=False
596
+ )
597
+ video_output_3 = gr.Video(
598
+ label="Sample 3",
599
+ height=150,
600
+ visible=False
601
+ )
602
+ with gr.Column(scale=1):
603
+ video_output_4 = gr.Video(
604
+ label="Sample 4",
605
+ height=150,
606
+ visible=False
607
+ )
608
+ video_output_5 = gr.Video(
609
+ label="Sample 5",
610
+ height=150,
611
+ visible=False
612
+ )
613
+
614
+ # Sample 6 - full width
615
+ video_output_6 = gr.Video(
616
+ label="Sample 6",
617
+ height=150,
618
+ visible=False
619
+ )
620
+
621
+ result_text = gr.Textbox(
622
+ label="Status",
623
+ interactive=False,
624
+ lines=2
625
+ )
626
+
627
+ # Examples section at the bottom
628
+ with gr.Column(elem_classes=["examples-section"]):
629
+ gr.Markdown("### 🌟 Examples")
630
+ gr.Markdown("Click on any example to load it into the interface above")
631
+
632
+ # Define your custom examples here - 8 examples total
633
+ examples_data = [
634
+ # Example 1
635
+ {
636
+ "caption": "A person walks on frozen ice",
637
+ "video_path": "examples/1_video.mp4",
638
+ "result_path": "examples/1_result.mp4"
639
+ },
640
+ # Example 2
641
+ {
642
+ "caption": "With a faint sound as their hands parted, the two embraced, a soft 'mm' escaping between them.",
643
+ "video_path": "examples/2_video.mp4",
644
+ "result_path": "examples/2_result.mp4"
645
+ },
646
+ # Example 3
647
+ {
648
+ "caption": "The sound of the number 3's bouncing footsteps is as light and clear as glass marbles hitting the ground. Each step carries a magical sound.",
649
+ "video_path": "examples/3_video.mp4",
650
+ "result_path": "examples/3_result.mp4"
651
+ },
652
+ # Example 4
653
+ {
654
+ "caption": "gentle gurgling of the stream's current, and music plays in the background which is a beautiful and serene piano solo with a hint of classical charm, evoking a sense of peace and serenity in people's hearts.",
655
+ "video_path": "examples/4_video.mp4",
656
+ "result_path": "examples/4_result.mp4"
657
+ },
658
+ # Example 5 - Add your new examples here
659
+ {
660
+ "caption": "snow crunching under the snowboard's edge.",
661
+ "video_path": "examples/5_video.mp4",
662
+ "result_path": "examples/5_result.mp4"
663
+ },
664
+ # Example 6
665
+ {
666
+ "caption": "The crackling of the fire, the whooshing of the flames, and the occasional crisp popping of charred leaves filled the forest.",
667
+ "video_path": "examples/6_video.mp4",
668
+ "result_path": "examples/6_result.mp4"
669
+ },
670
+ # Example 7
671
+ {
672
+ "caption": "humming of the scooter engine accelerates slowly.",
673
+ "video_path": "examples/7_video.mp4",
674
+ "result_path": "examples/7_result.mp4"
675
+ },
676
+ # Example 8
677
+ {
678
+ "caption": "splash of water and loud thud as person hits the surface.",
679
+ "video_path": "examples/8_video.mp4",
680
+ "result_path": "examples/8_result.mp4"
681
+ }
682
+ ]
683
+
684
+ # Create example grid - 4 examples per row, 2 rows total
685
+ example_buttons = []
686
+ for row in range(2): # 2 rows
687
+ with gr.Row(elem_classes=["example-grid-row"]):
688
+ for col in range(4): # 4 columns
689
+ idx = row * 4 + col
690
+ if idx < len(examples_data):
691
+ example = examples_data[idx]
692
+
693
+ with gr.Column(scale=1, elem_classes=["example-item"]):
694
+ # Video thumbnail
695
+ if os.path.exists(example['video_path']):
696
+ example_video = gr.Video(
697
+ value=example['video_path'],
698
+ label=f"Example {idx+1}",
699
+ interactive=False,
700
+ show_label=True,
701
+ height=180
702
+ )
703
+ else:
704
+ example_video = gr.HTML(f"""
705
+ <div style="background: #f0f0f0; padding: 15px; text-align: center; border-radius: 8px; height: 180px; display: flex; align-items: center; justify-content: center;">
706
+ <div>
707
+ <p style="color: #666; margin: 0; font-size: 12px;">📹 Video not found</p>
708
+ <small style="color: #999; font-size: 10px;">{example['video_path']}</small>
709
+ </div>
710
+ </div>
711
+ """)
712
+
713
+ # Caption (truncated for grid layout)
714
+ caption_preview = example['caption'][:60] + "..." if len(example['caption']) > 60 else example['caption']
715
+ gr.Markdown(f"{caption_preview}", elem_classes=["example-caption"])
716
+
717
+ # Load button
718
+ example_btn = gr.Button(
719
+ f"Load Example {idx+1}",
720
+ variant="secondary",
721
+ size="sm"
722
+ )
723
+ example_buttons.append((example_btn, example))
724
+
725
+ # Event handlers
726
+ def process_inference(video_file, text_prompt, neg_prompt, guidance_scale, inference_steps, sample_nums):
727
+ # Generate videos
728
+ video_list, status_msg = infer_single_video(
729
+ video_file, text_prompt, neg_prompt, guidance_scale, inference_steps, int(sample_nums)
730
+ )
731
+ # Update outputs with proper visibility
732
+ return update_video_outputs(video_list, status_msg)
733
+
734
+ # Add dynamic visibility control based on sample_nums
735
+ def update_visibility(sample_nums):
736
+ sample_nums = int(sample_nums)
737
+ return [
738
+ gr.update(visible=True), # Sample 1 always visible
739
+ gr.update(visible=sample_nums >= 2), # Sample 2
740
+ gr.update(visible=sample_nums >= 3), # Sample 3
741
+ gr.update(visible=sample_nums >= 4), # Sample 4
742
+ gr.update(visible=sample_nums >= 5), # Sample 5
743
+ gr.update(visible=sample_nums >= 6), # Sample 6
744
+ ]
745
+
746
+ # Update visibility when sample_nums changes
747
+ sample_nums.change(
748
+ fn=update_visibility,
749
+ inputs=[sample_nums],
750
+ outputs=[video_output_1, video_output_2, video_output_3, video_output_4, video_output_5, video_output_6]
751
+ )
752
+
753
+ generate_btn.click(
754
+ fn=process_inference,
755
+ inputs=[video_input, text_input, neg_prompt_input, guidance_scale, inference_steps, sample_nums],
756
+ outputs=[
757
+ video_output_1, # Sample 1 value
758
+ video_output_2, # Sample 2 value
759
+ video_output_3, # Sample 3 value
760
+ video_output_4, # Sample 4 value
761
+ video_output_5, # Sample 5 value
762
+ video_output_6, # Sample 6 value
763
+ result_text
764
+ ]
765
+ )
766
+
767
+ # Add click handlers for example buttons
768
+ for btn, example in example_buttons:
769
+ def create_example_handler(ex):
770
+ def handler():
771
+ # Check if files exist, if not, return placeholder message
772
+ if os.path.exists(ex['video_path']):
773
+ video_file = ex['video_path']
774
+ else:
775
+ video_file = None
776
+
777
+ if os.path.exists(ex['result_path']):
778
+ result_video = ex['result_path']
779
+ else:
780
+ result_video = None
781
+
782
+ status_msg = f"✅ Loaded example with caption: {ex['caption'][:50]}..."
783
+ if not video_file:
784
+ status_msg += f"\n⚠️ Video file not found: {ex['video_path']}"
785
+ if not result_video:
786
+ status_msg += f"\n⚠️ Result video not found: {ex['result_path']}"
787
+
788
+ return video_file, ex['caption'], "noisy, harsh", result_video, status_msg
789
+ return handler
790
+
791
+ btn.click(
792
+ fn=create_example_handler(example),
793
+ outputs=[video_input, text_input, neg_prompt_input, video_output_1, result_text]
794
+ )
795
+
796
+ # Footer
797
+ gr.HTML("""
798
+ <div class="footer-text">
799
+ <p>🚀 Powered by HunyuanVideo-Foley | Generate high-quality audio from video and text descriptions</p>
800
+ </div>
801
+ """)
802
+
803
+ return app
804
+
805
+ def set_manual_seed(global_seed):
806
+ random.seed(global_seed)
807
+ np.random.seed(global_seed)
808
+ torch.manual_seed(global_seed)
809
+
810
+ if __name__ == "__main__":
811
+ set_manual_seed(1)
812
+ # Setup logging
813
+ logger.remove()
814
+ logger.add(lambda msg: print(msg, end=''), level="INFO")
815
+
816
+ # Auto-load model
817
+ logger.info("Starting application and loading model...")
818
+ model_load_result = auto_load_models()
819
+ logger.info(model_load_result)
820
+
821
+ # Create and launch Gradio app
822
+ app = create_gradio_interface()
823
+
824
+ # Log completion status
825
+ if "successfully" in model_load_result:
826
+ logger.info("Application ready, model loaded")
827
+
828
+ app.launch(
829
+ server_name="0.0.0.0",
830
+ server_port=8080,
831
+ share=False,
832
+ debug=False,
833
+ show_error=True
834
+ )
HunyuanVideo-Foley/hunyuanvideo_foley/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment
3
+ for High-Fidelity Foley Audio Generation
4
+
5
+ This package provides tools for generating high-quality Foley audio effects
6
+ from video content using multimodal diffusion models.
7
+ """
8
+
9
+ __version__ = "1.0.0"
10
+ __author__ = "Tencent Hunyuan Team"
11
+ __email__ = "hunyuan@tencent.com"
12
+
13
+ # Import main components for easy access
14
+ try:
15
+ from .utils.model_utils import load_model, denoise_process
16
+ from .utils.feature_utils import feature_process
17
+ from .utils.media_utils import merge_audio_video
18
+ from .utils.config_utils import AttributeDict
19
+
20
+ __all__ = [
21
+ "__version__",
22
+ "load_model",
23
+ "denoise_process",
24
+ "feature_process",
25
+ "merge_audio_video",
26
+ "AttributeDict"
27
+ ]
28
+ except ImportError:
29
+ # Handle missing dependencies gracefully during installation
30
+ __all__ = ["__version__"]
HunyuanVideo-Foley/hunyuanvideo_foley/cli.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Command Line Interface for HunyuanVideo-Foley
4
+
5
+ Provides command-line access to the main inference functionality.
6
+ """
7
+
8
+ import sys
9
+ import argparse
10
+ from pathlib import Path
11
+
12
def main():
    """Main CLI entry point.

    Parses command-line arguments, validates them, and dispatches to one of
    three modes: Gradio web UI, single-video inference, or CSV batch inference.
    Exits with status 1 on missing dependencies or any runtime error.
    """
    parser = argparse.ArgumentParser(
        description="HunyuanVideo-Foley: Generate Foley audio from video and text",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Single video generation
  hunyuanvideo-foley --model_path ./models --single_video video.mp4 --single_prompt "footsteps on gravel"

  # Batch processing
  hunyuanvideo-foley --model_path ./models --csv_path batch.csv --output_dir ./outputs

  # Start Gradio interface
  hunyuanvideo-foley --gradio --model_path ./models
        """
    )

    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the pretrained model directory")
    parser.add_argument("--config_path", type=str,
                        default="configs/hunyuanvideo-foley-xxl.yaml",
                        help="Path to the model configuration file")

    # Input options — exactly one of the three modes must be chosen.
    group_input = parser.add_mutually_exclusive_group(required=True)
    group_input.add_argument("--single_video", type=str,
                             help="Path to single video file for processing")
    group_input.add_argument("--csv_path", type=str,
                             help="Path to CSV file with video paths and prompts")
    group_input.add_argument("--gradio", action="store_true",
                             help="Launch Gradio web interface")

    # Generation options
    parser.add_argument("--single_prompt", type=str,
                        help="Text prompt for single video (required with --single_video)")
    parser.add_argument("--output_dir", type=str, default="./outputs",
                        help="Output directory for generated audio files")
    parser.add_argument("--guidance_scale", type=float, default=4.5,
                        help="Guidance scale for generation (default: 4.5)")
    parser.add_argument("--num_inference_steps", type=int, default=50,
                        help="Number of inference steps (default: 50)")
    parser.add_argument("--neg_prompt", type=str,
                        help="Negative prompt to avoid certain audio characteristics")

    # System options
    parser.add_argument("--device", type=str, default="auto",
                        choices=["auto", "cpu", "cuda"],
                        help="Device to use for inference")
    parser.add_argument("--gpu_id", type=int, default=0,
                        help="GPU ID to use (default: 0)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducible generation")

    args = parser.parse_args()

    # Validate arguments: argparse cannot express "A requires B", so enforce
    # the single_video/single_prompt pairing manually.
    if args.single_video and not args.single_prompt:
        parser.error("--single_prompt is required when using --single_video")

    # Import here to avoid import errors if dependencies are missing
    # (the helpers below import subprocess/os lazily inside themselves).
    try:
        if args.gradio:
            _launch_gradio(args)
        elif args.single_video:
            _process_single_video(args)
        elif args.csv_path:
            _process_batch(args)
    except ImportError as e:
        print(f"Error: Missing required dependencies. Please install with: pip install hunyuanvideo-foley[all]")
        print(f"Import error: {e}")
        sys.exit(1)
    except Exception as e:
        # Broad catch is deliberate at this top-level CLI boundary: report and
        # exit non-zero rather than dump a traceback at the user.
        print(f"Error: {e}")
        sys.exit(1)
87
+
88
+ def _launch_gradio(args):
89
+ """Launch Gradio web interface."""
90
+ import os
91
+ os.environ["HIFI_FOLEY_MODEL_PATH"] = args.model_path
92
+
93
+ # Import and launch gradio app
94
+ import subprocess
95
+ gradio_script = Path(__file__).parent.parent / "gradio_app.py"
96
+ subprocess.run([sys.executable, str(gradio_script)])
97
+
98
+ def _process_single_video(args):
99
+ """Process a single video file."""
100
+ from . import infer
101
+
102
+ print(f"Processing video: {args.single_video}")
103
+ print(f"Prompt: {args.single_prompt}")
104
+
105
+ # This would need to be implemented to match the actual infer.py interface
106
+ # For now, redirect to the original script
107
+ import subprocess
108
+ cmd = [
109
+ sys.executable, "infer.py",
110
+ "--model_path", args.model_path,
111
+ "--config_path", args.config_path,
112
+ "--single_video", args.single_video,
113
+ "--single_prompt", args.single_prompt,
114
+ "--output_dir", args.output_dir,
115
+ "--guidance_scale", str(args.guidance_scale),
116
+ "--num_inference_steps", str(args.num_inference_steps)
117
+ ]
118
+ if args.neg_prompt:
119
+ cmd.extend(["--neg_prompt", args.neg_prompt])
120
+
121
+ subprocess.run(cmd)
122
+
123
+ def _process_batch(args):
124
+ """Process a batch of videos from CSV."""
125
+ import subprocess
126
+ cmd = [
127
+ sys.executable, "infer.py",
128
+ "--model_path", args.model_path,
129
+ "--config_path", args.config_path,
130
+ "--csv_path", args.csv_path,
131
+ "--output_dir", args.output_dir,
132
+ "--guidance_scale", str(args.guidance_scale),
133
+ "--num_inference_steps", str(args.num_inference_steps)
134
+ ]
135
+ if args.neg_prompt:
136
+ cmd.extend(["--neg_prompt", args.neg_prompt])
137
+
138
+ subprocess.run(cmd)
139
+
140
if __name__ == "__main__":
    # Allow running this module directly as a script (python -m ... / python cli.py).
    main()
HunyuanVideo-Foley/hunyuanvideo_foley/constants.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Constants used throughout the HunyuanVideo-Foley project.

Pure data module: defaults, limits, and lookup tables shared by the CLI,
the Gradio app, and the inference utilities. No runtime side effects.
"""

from typing import Dict, List

# Model configuration
DEFAULT_AUDIO_SAMPLE_RATE = 48000  # Hz
DEFAULT_VIDEO_FPS = 25
DEFAULT_AUDIO_CHANNELS = 2  # stereo

# Video processing — accepted input duration bounds (seconds)
MAX_VIDEO_DURATION_SECONDS = 15.0
MIN_VIDEO_DURATION_SECONDS = 1.0

# Audio processing
AUDIO_VAE_LATENT_DIM = 128
AUDIO_FRAME_RATE = 75  # frames per second in latent space

# Visual features: sampling FPS per visual encoder
FPS_VISUAL: Dict[str, int] = {
    "siglip2": 8,
    "synchformer": 25
}

# Model paths (can be overridden by environment variables)
DEFAULT_MODEL_PATH = "./pretrained_models/"
DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"

# Inference parameters — defaults plus UI slider bounds
DEFAULT_GUIDANCE_SCALE = 4.5
DEFAULT_NUM_INFERENCE_STEPS = 50
MIN_GUIDANCE_SCALE = 1.0
MAX_GUIDANCE_SCALE = 10.0
MIN_INFERENCE_STEPS = 10
MAX_INFERENCE_STEPS = 100

# Text processing
MAX_TEXT_LENGTH = 100  # maximum prompt length in characters
DEFAULT_NEGATIVE_PROMPT = "noisy, harsh"

# File extensions
SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"]
SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"]

# Quality settings — extra ffmpeg arguments per quality tier
AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = {
    "high": ["-b:a", "192k"],
    "medium": ["-b:a", "128k"],
    "low": ["-b:a", "96k"]
}

# Error messages (user-facing; some are format templates filled at call sites)
ERROR_MESSAGES: Dict[str, str] = {
    "model_not_loaded": "Model is not loaded. Please load the model first.",
    "invalid_video_format": "Unsupported video format. Supported formats: {formats}",
    "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds",
    "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html"
}
HunyuanVideo-Foley/hunyuanvideo_foley/models/__init__.py ADDED
File without changes
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
__version__ = "1.0.0"

# preserved here for legacy reasons
__model_version__ = "latest"

import audiotools

# Register module patterns with audiotools' model serialization machinery:
# INTERN marks code saved/loaded alongside checkpoints, EXTERN marks external
# dependencies. NOTE(review): the "dac.**" pattern refers to the original
# top-level `dac` package name — confirm it still matches now that this code
# lives under hunyuanvideo_foley.models.dac_vae.
audiotools.ml.BaseModel.INTERN += ["dac.**"]
audiotools.ml.BaseModel.EXTERN += ["einops"]


# Re-export the public API of this vendored DAC VAE package.
from . import nn
from . import model
from . import utils
from .model import DAC
from .model import DACFile
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/__main__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import argbind
4
+
5
+ from .utils import download
6
+ from .utils.decode import decode
7
+ from .utils.encode import encode
8
+
9
STAGES = ["encode", "decode", "download"]


def run(stage: str):
    """Dispatch and execute one of the CLI stages.

    Parameters
    ----------
    stage : str
        Stage to run; must be one of ``STAGES``.

    Raises
    ------
    ValueError
        If ``stage`` is not a known stage name.
    """
    if stage not in STAGES:
        raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")

    # Resolve the stage name to the module-level callable of the same name
    # (encode/decode/download imported above) and invoke it. Every stage is
    # called without arguments; argbind supplies configuration via its scope.
    globals()[stage]()
29
+
30
+
31
if __name__ == "__main__":
    # First positional argument selects the stage (encode/decode/download);
    # pop it so argbind only sees that stage's own options.
    group = sys.argv.pop(1)
    args = argbind.parse_args(group=group)

    # argbind.scope makes the parsed arguments visible to the bound stage
    # functions for the duration of the call.
    with argbind.scope(args):
        run(group)
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .base import CodecMixin
2
+ from .base import DACFile
3
+ from .dac import DAC
4
+ from .discriminator import Discriminator
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/base.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
@dataclass
class DACFile:
    """Container for compressed DAC codes plus the metadata needed to decode.

    Serialized to/from ``.dac`` files via numpy's pickle-backed object format.
    """

    # Quantized codebook indices, shape [B x N x T] (values fit in uint16).
    codes: torch.Tensor

    # Metadata
    chunk_length: int      # latent frames per chunk used during compression
    original_length: int   # sample count of the original audio
    input_db: float        # input loudness; NOTE(review): annotated float but
                           # save() calls .numpy() on it — in practice this is
                           # a torch tensor returned by loudness(). Confirm.
    channels: int
    sample_rate: int       # original (pre-resample) sample rate
    padding: bool          # padding mode active when codes were produced
    dac_version: str

    def save(self, path):
        """Serialize codes + metadata to ``path`` (suffix forced to .dac)."""
        artifacts = {
            # uint16 is sufficient: codebook sizes are <= 65536.
            "codes": self.codes.numpy().astype(np.uint16),
            "metadata": {
                "input_db": self.input_db.numpy().astype(np.float32),
                "original_length": self.original_length,
                "sample_rate": self.sample_rate,
                "chunk_length": self.chunk_length,
                "channels": self.channels,
                "padding": self.padding,
                "dac_version": SUPPORTED_VERSIONS[-1],
            },
        }
        path = Path(path).with_suffix(".dac")
        with open(path, "wb") as f:
            # np.save with a dict payload pickles the object.
            np.save(f, artifacts)
        return path

    @classmethod
    def load(cls, path):
        """Load a ``.dac`` file previously written by :meth:`save`.

        Raises RuntimeError if the file was written by an unsupported codec
        version. NOTE(review): allow_pickle=True executes pickle on file
        contents — only load .dac files from trusted sources.
        """
        artifacts = np.load(path, allow_pickle=True)[()]
        codes = torch.from_numpy(artifacts["codes"].astype(int))
        if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
            raise RuntimeError(
                f"Given file {path} can't be loaded with this version of descript-audio-codec."
            )
        return cls(codes=codes, **artifacts["metadata"])
55
+
56
+
57
class CodecMixin:
    """Mixin adding windowed compress/decompress and padding bookkeeping to a codec.

    Expects the host class to provide: ``sample_rate``, ``hop_length``,
    ``delay``, ``device``, ``eval()``, ``modules()``, ``preprocess()``,
    ``encode()``, ``decode()`` and (for decompress) ``quantizer``.
    """

    @property
    def padding(self):
        # Lazily defaults to True the first time it is read.
        if not hasattr(self, "_padding"):
            self._padding = True
        return self._padding

    @padding.setter
    def padding(self, value):
        assert isinstance(value, bool)

        layers = [
            l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
        ]

        # Toggle conv padding on/off across the whole model. When disabling,
        # the original padding is stashed on each layer so it can be restored.
        for layer in layers:
            if value:
                if hasattr(layer, "original_padding"):
                    layer.padding = layer.original_padding
            else:
                layer.original_padding = layer.padding
                layer.padding = tuple(0 for _ in range(len(layer.padding)))

        self._padding = value

    def get_delay(self):
        """Return the model's edge delay (samples trimmed per side) under padding=False."""
        # Any number works here, delay is invariant to input length
        l_out = self.get_output_length(0)
        L = l_out

        layers = []
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                layers.append(layer)

        # Walk the conv stack backwards, inverting each layer's length formula
        # to recover the input length that produces l_out.
        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        """Forward-propagate a length through all conv layers with zero padding."""
        L = input_length
        # Calculate output length
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L

    @torch.no_grad()
    def compress(
        self,
        audio_path_or_signal: Union[str, Path, AudioSignal],
        win_duration: float = 1.0,
        verbose: bool = False,
        normalize_db: float = -16,
        n_quantizers: int = None,
    ) -> DACFile:
        """Processes an audio signal from a file or AudioSignal object into
        discrete codes. This function processes the signal in short windows,
        using constant GPU memory.

        Parameters
        ----------
        audio_path_or_signal : Union[str, Path, AudioSignal]
            audio signal to reconstruct
        win_duration : float, optional
            window duration in seconds, by default 1.0
        verbose : bool, optional
            by default False
        normalize_db : float, optional
            normalize db, by default -16
        n_quantizers : int, optional
            number of quantizers passed to encode, by default None (all)

        Returns
        -------
        DACFile
            Object containing compressed codes and metadata
            required for decompression
        """
        audio_signal = audio_path_or_signal
        if isinstance(audio_signal, (str, Path)):
            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

        self.eval()
        original_padding = self.padding
        original_device = audio_signal.device

        # Work on a mono copy; channels are folded into the batch dim below.
        audio_signal = audio_signal.clone()
        audio_signal = audio_signal.to_mono()
        original_sr = audio_signal.sample_rate

        resample_fn = audio_signal.resample
        loudness_fn = audio_signal.loudness

        # For very long audio (>= 10 hours; 10*60*60 s — the upstream comment
        # saying "10 minutes" did not match this check), use the ffmpeg versions
        if audio_signal.signal_duration >= 10 * 60 * 60:
            resample_fn = audio_signal.ffmpeg_resample
            loudness_fn = audio_signal.ffmpeg_loudness

        original_length = audio_signal.signal_length
        resample_fn(self.sample_rate)
        # Record input loudness so decompress() can restore it.
        input_db = loudness_fn()

        if normalize_db is not None:
            audio_signal.normalize(normalize_db)
        audio_signal.ensure_max_of_audio()

        # Fold channels into the batch dimension: [B, C, T] -> [B*C, 1, T].
        nb, nac, nt = audio_signal.audio_data.shape
        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
        win_duration = (
            audio_signal.signal_duration if win_duration is None else win_duration
        )

        if audio_signal.signal_duration <= win_duration:
            # Unchunked compression (used if signal length < win duration)
            self.padding = True
            n_samples = nt
            hop = nt
        else:
            # Chunked inference
            self.padding = False
            # Zero-pad signal on either side by the delay
            audio_signal.zero_pad(self.delay, self.delay)
            n_samples = int(win_duration * self.sample_rate)
            # Round n_samples to nearest hop length multiple
            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
            hop = self.get_output_length(n_samples)

        codes = []
        range_fn = range if not verbose else tqdm.trange

        for i in range_fn(0, nt, hop):
            x = audio_signal[..., i : i + n_samples]
            # Pad the last (possibly short) window up to the full window size.
            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

            audio_data = x.audio_data.to(self.device)
            audio_data = self.preprocess(audio_data, self.sample_rate)
            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
            codes.append(c.to(original_device))
            chunk_length = c.shape[-1]

        codes = torch.cat(codes, dim=-1)

        dac_file = DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=original_length,
            input_db=input_db,
            channels=nac,
            sample_rate=original_sr,
            padding=self.padding,
            dac_version=SUPPORTED_VERSIONS[-1],
        )

        # NOTE(review): this slice rebinds the local `codes` only, AFTER
        # dac_file was constructed — it has no effect on the returned file.
        # n_quantizers truncation (if intended here) never reaches dac_file.
        if n_quantizers is not None:
            codes = codes[:, :n_quantizers, :]

        self.padding = original_padding
        return dac_file

    @torch.no_grad()
    def decompress(
        self,
        obj: Union[str, Path, DACFile],
        verbose: bool = False,
    ) -> AudioSignal:
        """Reconstruct audio from a given .dac file

        Parameters
        ----------
        obj : Union[str, Path, DACFile]
            .dac file location or corresponding DACFile object.
        verbose : bool, optional
            Prints progress if True, by default False

        Returns
        -------
        AudioSignal
            Object with the reconstructed audio
        """
        self.eval()
        if isinstance(obj, (str, Path)):
            obj = DACFile.load(obj)

        # Restore the padding mode used during compression for the decode pass.
        original_padding = self.padding
        self.padding = obj.padding

        range_fn = range if not verbose else tqdm.trange
        codes = obj.codes
        original_device = codes.device
        chunk_length = obj.chunk_length
        recons = []

        # Decode chunk by chunk to keep GPU memory constant.
        for i in range_fn(0, codes.shape[-1], chunk_length):
            c = codes[..., i : i + chunk_length].to(self.device)
            z = self.quantizer.from_codes(c)[0]
            r = self.decode(z)
            recons.append(r.to(original_device))

        recons = torch.cat(recons, dim=-1)
        recons = AudioSignal(recons, self.sample_rate)

        resample_fn = recons.resample
        loudness_fn = recons.loudness

        # For very long audio (>= 10 hours), use the ffmpeg versions
        if recons.signal_duration >= 10 * 60 * 60:
            resample_fn = recons.ffmpeg_resample
            loudness_fn = recons.ffmpeg_loudness

        # Restore the original loudness recorded at compression time.
        if obj.input_db is not None:
            recons.normalize(obj.input_db)

        resample_fn(obj.sample_rate)

        if obj.original_length is not None:
            # Trim decoder padding back to the exact original length and
            # unfold channels from the batch dimension.
            recons = recons[..., : obj.original_length]
            loudness_fn()
            recons.audio_data = recons.audio_data.reshape(
                -1, obj.channels, obj.original_length
            )
        else:
            loudness_fn()

        self.padding = original_padding
        return recons
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/dac.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from ..nn.layers import Snake1d
13
+ from ..nn.layers import WNConv1d
14
+ from ..nn.layers import WNConvTranspose1d
15
+ from ..nn.quantize import ResidualVectorQuantize
16
+ from ..nn.vae_utils import DiagonalGaussianDistribution
17
+
18
+
19
def init_weights(m):
    """Weight initializer applied via ``model.apply``.

    Conv1d weights get a truncated normal (std=0.02) and biases are zeroed;
    all other module types pass through untouched.

    Args:
        m: A submodule handed in by ``nn.Module.apply``.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Guard against bias-free convolutions: Conv1d(..., bias=False) has
        # m.bias is None, and nn.init.constant_(None, 0) raises.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
23
+
24
+
25
class ResidualUnit(nn.Module):
    """Dilated residual block: Snake -> dilated 7-tap conv -> Snake -> 1x1 conv.

    The skip connection is center-cropped to match the branch output when the
    branch shrinks the time axis (padding disabled via CodecMixin).
    """

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        # "Same" padding for a kernel-7 dilated conv.
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        # Center-crop x so the residual add lines up when padding is off.
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y
42
+
43
+
44
class EncoderBlock(nn.Module):
    """Encoder stage: three dilated residual units then a strided downsampling conv.

    Operates at ``dim // 2`` channels and doubles to ``dim`` while
    downsampling the time axis by ``stride``.
    """

    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)
63
+
64
+
65
class Encoder(nn.Module):
    """Convolutional waveform encoder: [B, 1, T] -> [B, d_latent, T / prod(strides)].

    Channel width doubles at every EncoderBlock while the time axis is
    downsampled by the corresponding stride.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        # Final channel width after doubling, i.e. d_model * 2**len(strides).
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)
93
+
94
+
95
class DecoderBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling by ``stride`` followed by
    three dilated residual units (mirror of EncoderBlock)."""

    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                # Compensate for odd strides so output length is exactly T*stride.
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)
115
+
116
+
117
class Decoder(nn.Module):
    """Latent-to-waveform decoder: [B, input_channel, T] -> [B, d_out, T * prod(rates)].

    Channel width halves at each upsampling stage; the final Tanh bounds the
    output waveform to [-1, 1].
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
    ):
        super().__init__()

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]

        # Add final conv layer.
        # NOTE(review): output_dim is the value left over from the last loop
        # iteration — this raises NameError if `rates` is empty.
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
147
+
148
+
149
class DAC(BaseModel, CodecMixin):
    """Descript Audio Codec autoencoder with two bottleneck modes.

    continuous=False (default): residual vector quantization (RVQ) bottleneck
    producing discrete codes. continuous=True: VAE-style diagonal Gaussian
    bottleneck (quant_conv / post_quant_conv) with a KL term instead of codes.
    """

    def __init__(
        self,
        encoder_dim: int = 64,
        # NOTE(review): mutable default lists — safe only because they are
        # never mutated; kept as-is to preserve the checkpointed signature.
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 44100,
        continuous: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate
        self.continuous = continuous

        # Default latent width matches the encoder's final channel count.
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        # Total downsampling factor: samples per latent frame.
        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)

        if not continuous:
            # Discrete bottleneck: residual vector quantizer.
            self.n_codebooks = n_codebooks
            self.codebook_size = codebook_size
            self.codebook_dim = codebook_dim
            self.quantizer = ResidualVectorQuantize(
                input_dim=latent_dim,
                n_codebooks=n_codebooks,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_dropout=quantizer_dropout,
            )
        else:
            # Continuous bottleneck: 1x1 convs projecting to mean/logvar and back.
            self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
            self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
        )
        # NOTE(review): redundant — sample_rate was already assigned above.
        self.sample_rate = sample_rate
        self.apply(init_weights)

        # Edge delay in samples (see CodecMixin.get_delay), used for chunked compress.
        self.delay = self.get_delay()

    @property
    def dtype(self):
        """Get the dtype of the model parameters."""
        # Return the dtype of the first parameter found
        for param in self.parameters():
            return param.dtype
        return torch.float32  # fallback for a parameterless model

    @property
    def device(self):
        """Get the device of the model parameters."""
        # Return the device of the first parameter found
        for param in self.parameters():
            return param.device
        return torch.device('cpu')  # fallback for a parameterless model

    def preprocess(self, audio_data, sample_rate):
        """Validate the sample rate and right-pad audio to a hop-length multiple."""
        if sample_rate is None:
            sample_rate = self.sample_rate
        # The model is single-rate: resampling must happen before calling it.
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode given audio data and return the bottleneck representation.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used. Ignored when continuous=True.

        Returns
        -------
        tuple
            ``(z, codes, latents, commitment_loss, codebook_loss)``:
            discrete mode — quantized latents [B x D x T], codebook indices
            [B x N x T], projected pre-quantization latents [B x N*D x T],
            and the two VQ losses; continuous mode — z is a
            DiagonalGaussianDistribution and the remaining values are
            ``None, None, 0, 0``.
        """
        z = self.encoder(audio_data)  # [B x D x T]
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
        else:
            z = self.quant_conv(z)  # [B x 2D x T]
            z = DiagonalGaussianDistribution(z)
            codes, latents, commitment_loss, codebook_loss = None, None, 0, 0

        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode latents back to audio.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Bottleneck representation (quantized latents in discrete mode,
            a sampled latent in continuous mode).

        Returns
        -------
        Tensor[B x 1 x length]
            Decoded audio waveform.
        """
        if not self.continuous:
            audio = self.decoder(z)
        else:
            # Project the sampled latent back before decoding.
            z = self.post_quant_conv(z)
            audio = self.decoder(z)

        return audio

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            Discrete mode keys:
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
            "audio" : Tensor[B x 1 x length]
                Decoded audio data.
            Continuous mode keys: "audio", "z", "kl_loss".
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)

            # Trim decoder output back to the (pre-padding) input length.
            x = self.decode(z)
            return {
                "audio": x[..., :length],
                "z": z,
                "codes": codes,
                "latents": latents,
                "vq/commitment_loss": commitment_loss,
                "vq/codebook_loss": codebook_loss,
            }
        else:
            # VAE path: sample from the posterior and decode; KL regularizes it.
            posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
            z = posterior.sample()
            x = self.decode(z)

            kl_loss = posterior.kl()
            kl_loss = kl_loss.mean()

            return {
                "audio": x[..., :length],
                "z": z,
                "kl_loss": kl_loss,
            }
369
+
370
+
371
if __name__ == "__main__":
    # Smoke test: builds a default DAC on CPU, prints a parameter-annotated
    # module tree, measures the receptive field via gradients, and round-trips
    # a random signal through compress/decompress.
    import numpy as np
    from functools import partial

    model = DAC().to("cpu")

    # Annotate every module's repr with its parameter count (in millions).
    for n, m in model.named_modules():
        o = m.extra_repr()
        p = sum([np.prod(p.size()) for p in m.parameters()])
        fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
        setattr(m, "extra_repr", partial(fn, o=o, p=p))
    print(model)
    print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))

    length = 88200 * 2
    x = torch.randn(1, 1, length).to(model.device)
    x.requires_grad_(True)
    x.retain_grad()

    # Make a forward pass
    out = model(x)["audio"]
    print("Input shape:", x.shape)
    print("Output shape:", out.shape)

    # Create gradient variable: a unit gradient at the center output sample.
    grad = torch.zeros_like(out)
    grad[:, :, grad.shape[-1] // 2] = 1

    # Make a backward pass
    out.backward(grad)

    # Check non-zero values: input positions with non-zero grad are exactly
    # those inside the receptive field of the chosen output sample.
    gradmap = x.grad.squeeze(0)
    gradmap = (gradmap != 0).sum(0)  # sum across features
    rf = (gradmap != 0).sum()

    print(f"Receptive field: {rf.item()}")

    # End-to-end compress/decompress round trip on one minute of noise.
    x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
    model.decompress(model.compress(x, verbose=True), verbose=True)
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/model/discriminator.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from audiotools import AudioSignal
5
+ from audiotools import ml
6
+ from audiotools import STFTParams
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+
11
def WNConv1d(*args, **kwargs):
    """Weight-normalized ``nn.Conv1d``, followed by LeakyReLU(0.1) unless
    ``act=False`` is passed."""
    use_act = kwargs.pop("act", True)
    layer = weight_norm(nn.Conv1d(*args, **kwargs))
    return nn.Sequential(layer, nn.LeakyReLU(0.1)) if use_act else layer
17
+
18
+
19
def WNConv2d(*args, **kwargs):
    """Weight-normalized ``nn.Conv2d``, followed by LeakyReLU(0.1) unless
    ``act=False`` is passed."""
    use_act = kwargs.pop("act", True)
    layer = weight_norm(nn.Conv2d(*args, **kwargs))
    return nn.Sequential(layer, nn.LeakyReLU(0.1)) if use_act else layer
25
+
26
+
27
class MPD(nn.Module):
    """Multi-period discriminator: folds the waveform into a 2-D grid whose
    width is ``period`` and runs a stack of 2-D convolutions over it,
    returning every intermediate feature map (last entry is the logits)."""

    def __init__(self, period):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
            ]
        )
        self.conv_post = WNConv2d(
            1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
        )

    def pad_to_period(self, x):
        # Reflect-pad on the right so the time axis divides evenly by period.
        t = x.shape[-1]
        return F.pad(x, (0, self.period - t % self.period), mode="reflect")

    def forward(self, x):
        feature_maps = []

        x = self.pad_to_period(x)
        # (batch, chan, time) -> (batch, chan, time/period, period)
        x = rearrange(x, "b c (l p) -> b c l p", p=self.period)

        for conv in self.convs:
            x = conv(x)
            feature_maps.append(x)

        x = self.conv_post(x)
        feature_maps.append(x)

        return feature_maps
63
+
64
+
65
class MSD(nn.Module):
    """Multi-scale discriminator: resamples the waveform down by ``rate`` and
    applies a stack of grouped 1-D convolutions, returning every intermediate
    feature map (last entry is the logits)."""

    def __init__(self, rate: int = 1, sample_rate: int = 44100):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                WNConv1d(1, 16, 15, 1, padding=7),
                WNConv1d(16, 64, 41, 4, groups=4, padding=20),
                WNConv1d(64, 256, 41, 4, groups=16, padding=20),
                WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
                WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
                WNConv1d(1024, 1024, 5, 1, padding=2),
            ]
        )
        self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
        self.sample_rate = sample_rate
        self.rate = rate

    def forward(self, x):
        # Downsample to sample_rate / rate before discriminating.
        signal = AudioSignal(x, self.sample_rate)
        signal.resample(self.sample_rate // self.rate)
        x = signal.audio_data

        feature_maps = []
        for conv in self.convs:
            x = conv(x)
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)

        return feature_maps
96
+
97
+
98
# Default frequency bands (as fractions of the spectrum) for the MRD.
BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]


class MRD(nn.Module):
    def __init__(
        self,
        window_length: int,
        hop_factor: float = 0.25,
        sample_rate: int = 44100,
        bands: list = BANDS,
    ):
        """Complex multi-band spectrogram discriminator.

        Parameters
        ----------
        window_length : int
            Window length of STFT.
        hop_factor : float, optional
            Hop factor of the STFT, defaults to ``0.25 * window_length``.
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 44100
        bands : list, optional
            Bands (fractions of the spectrum) to run discriminator over.
        """
        super().__init__()

        self.window_length = window_length
        self.hop_factor = hop_factor
        self.sample_rate = sample_rate
        self.stft_params = STFTParams(
            window_length=window_length,
            hop_length=int(window_length * hop_factor),
            match_stride=True,
        )

        # Convert fractional band edges into frequency-bin index ranges.
        n_fft = window_length // 2 + 1
        self.bands = [(int(lo * n_fft), int(hi * n_fft)) for lo, hi in bands]

        ch = 32
        make_stack = lambda: nn.ModuleList(
            [
                WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
            ]
        )
        # One independent conv stack per frequency band.
        self.band_convs = nn.ModuleList([make_stack() for _ in range(len(self.bands))])
        self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)

    def spectrogram(self, x):
        # Complex STFT as a 2-channel (real/imag) image, then split into bands.
        signal = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
        spec = torch.view_as_real(signal.stft())
        spec = rearrange(spec, "b 1 f t c -> (b 1) c t f")
        return [spec[..., lo:hi] for lo, hi in self.bands]

    def forward(self, x):
        x_bands = self.spectrogram(x)
        feature_maps = []

        band_outputs = []
        for band, stack in zip(x_bands, self.band_convs):
            for conv in stack:
                band = conv(band)
                feature_maps.append(band)
            band_outputs.append(band)

        out = torch.cat(band_outputs, dim=-1)
        out = self.conv_post(out)
        feature_maps.append(out)

        return feature_maps
173
+
174
+
175
class Discriminator(ml.BaseModel):
    def __init__(
        self,
        rates: list = [],
        periods: list = [2, 3, 5, 7, 11],
        fft_sizes: list = [2048, 1024, 512],
        sample_rate: int = 44100,
        bands: list = BANDS,
    ):
        """Discriminator that combines multiple discriminators.

        Parameters
        ----------
        rates : list, optional
            sampling rates (in Hz) to run MSD at, by default []
            If empty, MSD is not used.
        periods : list, optional
            periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
        fft_sizes : list, optional
            Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 44100
        bands : list, optional
            Bands to run MRD at, by default `BANDS`
        """
        super().__init__()
        self.discriminators = nn.ModuleList(
            [MPD(p) for p in periods]
            + [MSD(r, sample_rate=sample_rate) for r in rates]
            + [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
        )

    def preprocess(self, y):
        # Remove DC offset
        y = y - y.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        return 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)

    def forward(self, x):
        x = self.preprocess(x)
        # One list of feature maps per sub-discriminator.
        return [d(x) for d in self.discriminators]
218
+
219
+
220
+ if __name__ == "__main__":
221
+ disc = Discriminator()
222
+ x = torch.zeros(1, 1, 44100)
223
+ results = disc(x)
224
+ for i, result in enumerate(results):
225
+ print(f"disc{i}")
226
+ for i, r in enumerate(result):
227
+ print(r.shape, r.mean(), r.min(), r.max())
228
+ print()
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import layers
2
+ from . import loss
3
+ from . import quantize
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/layers.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
def WNConv1d(*args, **kwargs):
    """``nn.Conv1d`` wrapped in weight normalization."""
    return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
def WNConvTranspose1d(*args, **kwargs):
    """``nn.ConvTranspose1d`` wrapped in weight normalization."""
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    # Snake activation: x + sin^2(alpha * x) / alpha, applied pointwise.
    # Flatten trailing dims, apply, then restore the original shape.
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    # 1e-9 guards against division by zero when alpha underflows.
    flat = flat + (alpha + 1e-9).reciprocal() * torch.sin(alpha * flat).pow(2)
    return flat.reshape(original_shape)
25
+
26
+
27
class Snake1d(nn.Module):
    """Channel-wise Snake activation with a learnable per-channel ``alpha``."""

    def __init__(self, channels):
        super().__init__()
        # One alpha per channel, broadcast over batch and time.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/loss.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from audiotools import AudioSignal
7
+ from audiotools import STFTParams
8
+ from torch import nn
9
+
10
+
11
class L1Loss(nn.L1Loss):
    """L1 Loss between AudioSignals. Defaults
    to comparing ``audio_data``, but any
    attribute of an AudioSignal can be used.

    Parameters
    ----------
    attribute : str, optional
        Attribute of signal to compare, defaults to ``audio_data``.
    weight : float, optional
        Weight of this loss, defaults to 1.0.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        self.attribute = attribute
        self.weight = weight
        super().__init__(**kwargs)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """L1 distance between ``self.attribute`` of estimate ``x`` and
        reference ``y``; raw tensors are compared directly."""
        if not isinstance(x, AudioSignal):
            return super().forward(x, y)
        return super().forward(getattr(x, self.attribute), getattr(y, self.attribute))
49
+
50
+
51
class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    Parameters
    ----------
    scaling : int, optional
        Whether to use scale-invariant (True) or
        signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch (either 'mean',
        'sum', or none), by default 'mean'
    zero_mean : int, optional
        Zero mean the references and estimates before
        computing the loss, by default True
    clip_min : int, optional
        The minimum possible loss value. Helps network
        to not focus on making already good examples better, by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(
        self,
        scaling: int = True,
        reduction: str = "mean",
        zero_mean: int = True,
        clip_min: int = None,
        weight: float = 1.0,
    ):
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight
        super().__init__()

    def forward(self, x: AudioSignal, y: AudioSignal):
        eps = 1e-8
        # Accept AudioSignals or raw tensors; layout is (nb, nc, nt).
        if isinstance(x, AudioSignal):
            references, estimates = x.audio_data, y.audio_data
        else:
            references, estimates = x, y

        batch = references.shape[0]
        # Flatten channel/time and move samples onto axis 1.
        references = references.reshape(batch, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(batch, 1, -1).permute(0, 2, 1)

        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate

        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps

        # Optimal scaling projects the estimate onto the reference (SI-SDR);
        # without scaling this degenerates to plain SNR.
        if self.scaling:
            scale = (references_on_estimates / references_projection).unsqueeze(1)
        else:
            scale = 1

        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        # Negated so a better ratio yields a lower loss.
        sdr = -10 * torch.log10(signal / noise + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr
140
+
141
+
142
class MultiScaleSTFTLoss(nn.Module):
    """Computes the multi-scale STFT loss from [1].

    Parameters
    ----------
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False

    References
    ----------

    1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
        "DDSP: Differentiable Digital Signal Processing."
        International Conference on Learning Representations. 2019.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: str = None,
    ):
        super().__init__()
        # One STFT configuration per scale; hop is a quarter window.
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.loss_fn = loss_fn
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.clamp_eps = clamp_eps
        self.weight = weight
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Multi-scale STFT distance between estimate ``x`` and reference ``y``.

        Returns
        -------
        torch.Tensor
            Sum over scales of weighted log-magnitude and raw-magnitude terms.
        """
        total = 0.0
        for params in self.stft_params:
            # stft() caches the result on the signal as ``magnitude``.
            x.stft(params.window_length, params.hop_length, params.window_type)
            y.stft(params.window_length, params.hop_length, params.window_type)
            log_x = x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10()
            log_y = y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10()
            total = total + self.log_weight * self.loss_fn(log_x, log_y)
            total = total + self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
        return total
229
+
230
+
231
class MelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80],
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[float] = [None, None],
        window_type: str = None,
    ):
        super().__init__()
        # One STFT configuration per scale; hop is a quarter window.
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.n_mels = n_mels
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Mel-spectrogram distance between estimate ``x`` and reference ``y``.

        Returns
        -------
        torch.Tensor
            Sum over scales of weighted log-magnitude and raw-magnitude terms.
        """
        total = 0.0
        for n_mels, fmin, fmax, params in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            stft_kwargs = {
                "window_length": params.window_length,
                "hop_length": params.hop_length,
                "window_type": params.window_type,
            }
            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **stft_kwargs)
            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **stft_kwargs)

            # Log-magnitude term ...
            total = total + self.log_weight * self.loss_fn(
                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            # ... plus raw-magnitude term.
            total = total + self.mag_weight * self.loss_fn(x_mels, y_mels)
        return total
328
+
329
+
330
class GANLoss(nn.Module):
    """
    Computes a discriminator loss, given a discriminator on
    generated waveforms/spectrograms compared to ground truth
    waveforms/spectrograms. Computes the loss for both the
    discriminator and the generator in separate functions.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, fake, real):
        # Run both waveforms through every sub-discriminator.
        d_fake = self.discriminator(fake.audio_data)
        d_real = self.discriminator(real.audio_data)
        return d_fake, d_real

    def discriminator_loss(self, fake, real):
        # Least-squares GAN objective; the generator output is detached so
        # only the discriminator receives gradients.
        d_fake, d_real = self.forward(fake.clone().detach(), real)

        loss_d = 0
        for x_fake, x_real in zip(d_fake, d_real):
            loss_d = loss_d + torch.mean(x_fake[-1] ** 2)
            loss_d = loss_d + torch.mean((1 - x_real[-1]) ** 2)
        return loss_d

    def generator_loss(self, fake, real):
        d_fake, d_real = self.forward(fake, real)

        # Adversarial term: push fake logits (last feature map) toward 1.
        loss_g = 0
        for x_fake in d_fake:
            loss_g = loss_g + torch.mean((1 - x_fake[-1]) ** 2)

        # Feature-matching term over all intermediate feature maps.
        loss_feature = 0
        for fake_maps, real_maps in zip(d_fake, d_real):
            for f_map, r_map in zip(fake_maps[:-1], real_maps[:-1]):
                loss_feature = loss_feature + F.l1_loss(f_map, r_map.detach())
        return loss_g, loss_feature
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/quantize.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from .layers import WNConv1d
11
+
12
+
13
class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
    1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
        for improved codebook usage
    2. l2-normalized codes: Converts euclidean distance to cosine similarity which
        improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Factorized codes: project down into the codebook space and back up.
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z):
        """Quantize ``z`` (B x D x T) against a fixed codebook.

        Returns
        -------
        z_q : Tensor[B x D x T]
            Quantized continuous representation of input.
        commitment_loss : Tensor[B]
            Pulls encoder outputs toward their codebook entries.
        codebook_loss : Tensor[B]
            Pulls codebook entries toward encoder outputs.
        indices : Tensor[B x T]
            Selected codebook indices (discrete representation).
        z_e : Tensor[B x codebook_dim x T]
            Projected latents before quantization.
        """
        z_e = self.in_proj(z)
        z_q, indices = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        # Straight-through estimator: identity in forward, passes gradients
        # from z_q straight back to z_e in backward.
        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_proj(z_q)
        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        # (B, T, D) -> (B, D, T)
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2-normalize both sides so the distance behaves like cosine similarity.
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Squared euclidean distance to every codebook entry.
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        return self.decode_code(indices), indices
95
+
96
+
97
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        # A scalar codebook_dim applies to every codebook.
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim] * n_codebooks

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantize ``z`` (B x D x T) with up to ``n_quantizers`` residual codebooks.

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if ``self.quantizer_dropout`` is set, this argument is
            ignored in training mode and a random number of quantizers is
            used for a prefix of the batch.

        Returns
        -------
        tuple
            (z_q, codes, latents, commitment_loss, codebook_loss) where
            z_q is Tensor[B x D x T], codes is Tensor[B x N x T], and
            latents is Tensor[B x N*D x T].
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            # Per-sample codebook budget: a random prefix of the batch gets a
            # random (smaller) number of codebooks (quantizer dropout).
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Mask out samples whose codebook budget is already exhausted.
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum masked losses across codebooks.
            commitment_loss = commitment_loss + (commitment_loss_i * mask).mean()
            codebook_loss = codebook_loss + (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation

        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        """
        z_q = 0.0
        z_p = []
        for i in range(codes.shape[1]):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)
            z_q = z_q + self.quantizers[i].out_proj(z_p_i)
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        """
        z_q = 0
        z_p = []
        codes = []
        # Channel offsets of each quantizer's slice within the stacked latents.
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # Use as many codebooks as fit into the latents' channel dimension.
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
            0
        ]
        for i in range(n_codebooks):
            lo, hi = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, lo:hi, :])
            z_p.append(z_p_i)
            codes.append(codes_i)
            z_q = z_q + self.quantizers[i].out_proj(z_p_i)

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
256
+
257
+
258
+ if __name__ == "__main__":
259
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
260
+ x = torch.randn(16, 512, 80)
261
+ y = rvq(x)
262
+ print(y["latents"].shape)
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
class AbstractDistribution:
    """Interface for distributions: subclasses must provide ``sample`` and
    ``mode``."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()
11
+
12
+
13
class DiracDistribution(AbstractDistribution):
    """Degenerate distribution concentrated on a single value: both
    ``sample`` and ``mode`` return it unchanged."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value
22
+
23
+
24
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterized by a tensor holding ``[mean, logvar]``
    concatenated along dim 1. ``deterministic=True`` collapses it to a Dirac
    at the mean (zero variance)."""

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp for numerical stability before exponentiating.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Degenerate case: zero std/var so sample() returns the mean.
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        # Reparameterized sample: mean + std * eps.
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """KL divergence against a standard normal (``other=None``) or
        another diagonal Gaussian, averaged over dims (1, 2)."""
        if self.deterministic:
            return torch.Tensor([0.0])
        if other is None:
            return 0.5 * torch.mean(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2],
            )
        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2],
        )

    def nll(self, sample, dims=[1, 2]):
        """Negative log-likelihood of ``sample``, summed over ``dims``."""
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean
69
+
70
+
71
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # Find any tensor argument so scalars can be promoted onto its dtype/device.
    tensor = next(
        (obj for obj in (mean1, logvar1, mean2, logvar2) if isinstance(obj, torch.Tensor)),
        None,
    )
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = (
        lv if isinstance(lv, torch.Tensor) else torch.tensor(lv).to(tensor)
        for lv in (logvar1, logvar2)
    )

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/__init__.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import argbind
4
+ from audiotools import ml
5
+
6
+ from ..model import DAC
7
# Re-export audiotools' Accelerator so callers can use `utils.Accelerator`.
Accelerator = ml.Accelerator

# (sample_rate, bitrate) -> most recent released model tag for that combo.
__MODEL_LATEST_TAGS__ = {
    ("44khz", "8kbps"): "0.0.1",
    ("24khz", "8kbps"): "0.0.4",
    ("16khz", "8kbps"): "0.0.5",
    ("44khz", "16kbps"): "1.0.0",
}

# (sample_rate, tag, bitrate) -> URL of the released weights file.
__MODEL_URLS__ = {
    (
        "44khz",
        "0.0.1",
        "8kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
    (
        "24khz",
        "0.0.4",
        "8kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
    (
        "16khz",
        "0.0.5",
        "8kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
    (
        "44khz",
        "1.0.0",
        "16kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
}
38
+
39
+
40
@argbind.bind(group="download", positional=True, without_prefix=True)
def download(
    model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
):
    """
    Function that downloads the weights file from URL if a local cache is not found.

    Parameters
    ----------
    model_type : str
        The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
    model_bitrate: str
        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
        Only 44khz model supports 16kbps.
    tag : str
        The tag of the model to download. Defaults to "latest".

    Returns
    -------
    Path
        Directory path required to load model via audiotools.

    Raises
    ------
    ValueError
        If no release matches the requested (model_type, tag, model_bitrate),
        or the HTTP download does not return status 200.
    """
    model_type = model_type.lower()
    tag = tag.lower()

    assert model_type in [
        "44khz",
        "24khz",
        "16khz",
    ], "model_type must be one of '44khz', '24khz', or '16khz'"

    assert model_bitrate in [
        "8kbps",
        "16kbps",
    ], "model_bitrate must be one of '8kbps', or '16kbps'"

    if tag == "latest":
        tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]

    download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)

    if download_link is None:
        raise ValueError(
            f"Could not find model with tag {tag} and model type {model_type}"
        )

    # Weights are cached under ~/.cache/descript/dac and only fetched once.
    local_path = (
        Path.home()
        / ".cache"
        / "descript"
        / "dac"
        / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
    )
    if not local_path.exists():
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Download the model. Import lazily so `requests` is only needed when
        # an actual download happens.
        import requests

        # requests has no default timeout; without one a stalled connection
        # would hang this call forever.
        response = requests.get(download_link, timeout=60)

        if response.status_code != 200:
            raise ValueError(
                f"Could not download model. Received response code {response.status_code}"
            )
        local_path.write_bytes(response.content)

    return local_path
108
+
109
+
110
def load_model(
    model_type: str = "44khz",
    model_bitrate: str = "8kbps",
    tag: str = "latest",
    load_path: str = None,
):
    """Load a DAC generator from `load_path`, downloading the matching
    released weights first when no path is provided."""
    if not load_path:
        load_path = download(
            model_type=model_type, model_bitrate=model_bitrate, tag=tag
        )
    return DAC.load(load_path)
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/decode.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from pathlib import Path
3
+
4
+ import argbind
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from tqdm import tqdm
9
+
10
+ from ..model import DACFile
11
+ from . import load_model
12
+
13
+ warnings.filterwarnings("ignore", category=UserWarning)
14
+
15
+
16
@argbind.bind(group="decode", positional=True, without_prefix=True)
@torch.inference_mode()
@torch.no_grad()
def decode(
    input: str,
    output: str = "",
    weights_path: str = "",
    model_tag: str = "latest",
    model_bitrate: str = "8kbps",
    device: str = "cuda",
    model_type: str = "44khz",
    verbose: bool = False,
):
    """Decode audio from codes.

    Parameters
    ----------
    input : str
        Path to input directory or file
    output : str, optional
        Path to output directory, by default "".
        If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
    weights_path : str, optional
        Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
        model_tag and model_type.
    model_tag : str, optional
        Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
    model_bitrate: str
        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
    device : str, optional
        Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
    model_type : str, optional
        The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
    verbose : bool, optional
        Forwarded to `generator.decompress`, by default False.
    """
    generator = load_model(
        model_type=model_type,
        model_bitrate=model_bitrate,
        tag=model_tag,
        load_path=weights_path,
    )
    generator.to(device)
    generator.eval()

    # Find all .dac files in input directory
    _input = Path(input)
    input_files = list(_input.glob("**/*.dac"))

    # If input is a .dac file, add it to the list
    if _input.suffix == ".dac":
        input_files.append(_input)

    # Create output directory
    output = Path(output)
    output.mkdir(parents=True, exist_ok=True)

    # Iterate the files directly instead of indexing into the list.
    for input_file in tqdm(input_files, desc="Decoding files"):
        # Load file
        artifact = DACFile.load(input_file)

        # Reconstruct audio from codes
        recons = generator.decompress(artifact, verbose=verbose)

        # Compute output path, mirroring the input sub-tree under `output`.
        relative_path = input_file.relative_to(input)
        output_dir = output / relative_path.parent
        if not relative_path.name:
            # `input` was the .dac file itself; keep its own name.
            output_dir = output
            relative_path = input_file
        output_name = relative_path.with_suffix(".wav").name
        output_path = output_dir / output_name
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Write to file
        recons.write(output_path)
90
+
91
+
92
if __name__ == "__main__":
    # Parse CLI arguments bound via argbind and run decode() inside that scope.
    args = argbind.parse_args()
    with argbind.scope(args):
        decode()
HunyuanVideo-Foley/hunyuanvideo_foley/models/dac_vae/utils/encode.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from pathlib import Path
4
+
5
+ import argbind
6
+ import numpy as np
7
+ import torch
8
+ from audiotools import AudioSignal
9
+ from audiotools.core import util
10
+ from tqdm import tqdm
11
+
12
+ from . import load_model
13
+
14
+ warnings.filterwarnings("ignore", category=UserWarning)
15
+
16
+
17
@argbind.bind(group="encode", positional=True, without_prefix=True)
@torch.inference_mode()
@torch.no_grad()
def encode(
    input: str,
    output: str = "",
    weights_path: str = "",
    model_tag: str = "latest",
    model_bitrate: str = "8kbps",
    n_quantizers: int = None,
    device: str = "cuda",
    model_type: str = "44khz",
    win_duration: float = 5.0,
    verbose: bool = False,
):
    """Encode audio files in input path to .dac format.

    Parameters
    ----------
    input : str
        Path to input audio file or directory
    output : str, optional
        Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
    weights_path : str, optional
        Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
        model_tag and model_type.
    model_tag : str, optional
        Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
    model_bitrate: str
        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
    n_quantizers : int, optional
        Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
    device : str, optional
        Device to use, by default "cuda"
    model_type : str, optional
        The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
    win_duration : float, optional
        Window duration (seconds) passed to `generator.compress`, by default 5.0.
    verbose : bool, optional
        Forwarded to `generator.compress`, by default False.
    """
    generator = load_model(
        model_type=model_type,
        model_bitrate=model_bitrate,
        tag=model_tag,
        load_path=weights_path,
    )
    generator.to(device)
    generator.eval()
    kwargs = {"n_quantizers": n_quantizers}

    # Find all audio files in input path
    input = Path(input)
    audio_files = util.find_audio(input)

    output = Path(output)
    output.mkdir(parents=True, exist_ok=True)

    # Iterate the files directly instead of indexing into the list.
    for audio_file in tqdm(audio_files, desc="Encoding files"):
        # Load file
        signal = AudioSignal(audio_file)

        # Encode audio to .dac format
        artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)

        # Compute output path, mirroring the input sub-tree under `output`.
        relative_path = audio_file.relative_to(input)
        output_dir = output / relative_path.parent
        if not relative_path.name:
            # `input` was the audio file itself; keep its own name.
            output_dir = output
            relative_path = audio_file
        output_name = relative_path.with_suffix(".dac").name
        output_path = output_dir / output_name
        output_path.parent.mkdir(parents=True, exist_ok=True)

        artifact.save(output_path)
89
+
90
+
91
if __name__ == "__main__":
    # Parse CLI arguments bound via argbind and run encode() inside that scope.
    args = argbind.parse_args()
    with argbind.scope(args):
        encode()
HunyuanVideo-Foley/hunyuanvideo_foley/models/hifi_foley.py ADDED
@@ -0,0 +1,794 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Optional, Union, Dict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from einops.layers.torch import Rearrange
8
+ from diffusers.models import ModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+
11
+ from .nn.activation_layers import SwiGLU, get_activation_layer
12
+ from .nn.attn_layers import apply_rotary_emb, attention
13
+ from .nn.embed_layers import TimestepEmbedder, ConditionProjection, PatchEmbed1D
14
+ from .nn.mlp_layers import MLP, ConvMLP, FinalLayer1D, ChannelLastConv1d
15
+ from .nn.modulate_layers import ModulateDiT, ckpt_wrapper, apply_gate, modulate
16
+ from .nn.norm_layers import get_norm_layer
17
+ from .nn.posemb_layers import get_nd_rotary_pos_embed
18
+
19
def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor):
    """Interleave two [B, N, H, C] sequences token-wise into [B, 2*N1, H, C].

    When the two sequences differ in length, x2 is resampled along the token
    axis (nearest-exact interpolation) to match x1 before interleaving, so the
    output alternates x1[0], x2[0], x1[1], x2[1], ...

    Fix: the rank assertion now runs *before* shape unpacking; previously a
    non-4D input raised an opaque unpacking error instead of the assert.
    """
    assert x1.ndim == x2.ndim == 4
    # [B, N1, H, C] & [B, N2, H, C]
    B, N1, H, C = x1.shape
    B, N2, H, C = x2.shape

    if N1 != N2:
        # Resample x2's token axis to length N1 (channels flattened to H*C).
        x2 = x2.view(B, N2, -1).transpose(1, 2)
        x2 = F.interpolate(x2, size=(N1), mode="nearest-exact")
        x2 = x2.transpose(1, 2).view(B, N1, H, C)
    x = torch.stack((x1, x2), dim=2)  # [B, N1, 2, H, C]
    x = x.reshape(B, N1 * 2, H, C)
    return x
32
+
33
def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int):
    """Split an interleaved [B, 2*len1, H, C] sequence back into its two
    streams, resampling the second stream to `len2` tokens when needed.

    Inverse of `interleave_two_sequences`.
    """
    B, N, H, C = x.shape
    assert N % 2 == 0 and N // 2 == len1

    # Group tokens into (x1, x2) pairs, then peel the two streams apart.
    pairs = x.reshape(B, -1, 2, H, C)
    x1, x2 = pairs[:, :, 0], pairs[:, :, 1]
    if x2.shape[1] != len2:
        # Bring x2 back to its original token length.
        x2 = x2.view(B, len1, H * C).transpose(1, 2)
        x2 = F.interpolate(x2, size=(len2), mode="nearest-exact")
        x2 = x2.transpose(1, 2).view(B, len2, H, C)
    return x1, x2
45
+
46
class TwoStreamCABlock(nn.Module):
    """Dual-stream transformer block: audio and visual tokens run joint
    self-attention, then both cross-attend to text tokens, then pass through
    per-stream MLPs. Each stage is modulated (shift/scale/gate) by a
    conditioning vector, DiT-style (9 modulation chunks = 3 stages x 3)."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        attn_mode: str = "torch",
        reverse: bool = False,
        interleaved_audio_visual_rope: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        # NOTE(review): `reverse` is stored but not read inside this block —
        # presumably consumed elsewhere; confirm before removing.
        self.reverse = reverse
        self.attn_mode = attn_mode
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        head_dim = hidden_size // num_heads
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.interleaved_audio_visual_rope = interleaved_audio_visual_rope

        # Self attention for audio + visual
        # factor=9: shift/scale/gate for each of the three stages below.
        self.audio_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
        self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.audio_self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.audio_self_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_self_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        # visual cond
        self.v_cond_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
        self.v_cond_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        self.v_cond_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        # NOTE(review): max_text_len is set but not referenced in this block.
        self.max_text_len = 100
        # None -> build_rope_for_text derives the per-axis split from head_dim.
        self.rope_dim_list = None

        # audio and video norm for cross attention with text
        self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        # Cross attention: (video_audio) as query, text as key/value
        self.audio_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.v_cond_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.text_cross_kv = nn.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs)

        self.audio_cross_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.v_cond_cross_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.text_cross_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.audio_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.v_cond_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        # MLPs
        self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.audio_mlp = MLP(
            hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
        )

        self.v_cond_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.v_cond_mlp = MLP(
            hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
        )

    def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
        """Build 1-D rotary (cos, sin) tables of length `text_len` whose
        per-axis dims sum to `head_dim`."""
        target_ndim = 1  # n-d RoPE
        rope_sizes = [text_len]

        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

        text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=rope_sizes,
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
        )
        return text_freqs_cos, text_freqs_sin

    def set_attn_mode(self, new_mode):
        """Switch the attention backend; only "torch" is implemented."""
        if new_mode != "torch":
            raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.")
        self.attn_mode = new_mode

    def enable_deterministic(self):
        """Request deterministic attention (forwarded to `attention`)."""
        self.deterministic = True

    def disable_deterministic(self):
        """Allow non-deterministic attention kernels."""
        self.deterministic = False

    def forward(
        self,
        audio: torch.Tensor,
        cond: torch.Tensor,
        v_cond: torch.Tensor,
        attn_mask: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: tuple = None,
        v_freqs_cis: tuple = None,
        sync_vec: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one block: joint (visual+audio) self-attention, cross-attention
        of both streams onto text `cond`, then gated MLPs.

        Returns (audio, cond, v_cond); `cond` is passed through unchanged.
        """
        # Get modulation parameters.
        # When a per-token sync vector is given, it modulates the audio stream
        # instead of the global `vec`.
        if sync_vec is not None:
            assert sync_vec.ndim == 3
            (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
             audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
             audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
             ) = self.audio_mod(sync_vec).chunk(9, dim=-1)
        else:
            (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
             audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
             audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
             ) = self.audio_mod(vec).chunk(9, dim=-1)

        (
            v_cond_mod1_shift,
            v_cond_mod1_scale,
            v_cond_mod1_gate,
            v_cond_mod2_shift,
            v_cond_mod2_scale,
            v_cond_mod2_gate,
            v_cond_mod3_shift,
            v_cond_mod3_scale,
            v_cond_mod3_gate,
        ) = self.v_cond_mod(vec).chunk(9, dim=-1)

        # 1. Self Attention for audio + visual
        audio_modulated = self.audio_norm1(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod1_shift, scale=audio_mod1_scale)
        audio_qkv = self.audio_self_attn_qkv(audio_modulated)
        audio_q, audio_k, audio_v = rearrange(audio_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        audio_q = self.audio_self_q_norm(audio_q).to(audio_v)
        audio_k = self.audio_self_k_norm(audio_k).to(audio_v)

        # Prepare visual cond for attention
        v_cond_modulated = self.v_cond_norm1(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod1_shift, scale=v_cond_mod1_scale)
        v_cond_qkv = self.v_cond_attn_qkv(v_cond_modulated)
        v_cond_q, v_cond_k, v_cond_v = rearrange(v_cond_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        v_cond_q = self.v_cond_attn_q_norm(v_cond_q).to(v_cond_v)
        v_cond_k = self.v_cond_attn_k_norm(v_cond_k).to(v_cond_v)

        # Apply RoPE if needed for audio and visual
        if freqs_cis is not None:
            if not self.interleaved_audio_visual_rope:
                audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False)
                audio_q, audio_k = audio_qq, audio_kk
            else:
                # Interleave audio/visual tokens so both share one positional
                # grid, rotate, then split them back apart.
                ori_audio_len = audio_q.shape[1]
                ori_v_con_len = v_cond_q.shape[1]
                interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q)
                interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k)
                interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb(
                    interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False
                )
                audio_qq, v_cond_qq = decouple_interleaved_two_sequences(
                    interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len
                )
                audio_kk, v_cond_kk = decouple_interleaved_two_sequences(
                    interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len
                )
                audio_q, audio_k = audio_qq, audio_kk
                v_cond_q, v_cond_k = v_cond_qq, v_cond_kk

        # Apply RoPE to visual if needed and not interleaved
        if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
            v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)
            v_cond_q, v_cond_k = v_cond_qq, v_cond_kk

        # Concatenate for self-attention: visual tokens first, audio second.
        q = torch.cat((v_cond_q, audio_q), dim=1)
        k = torch.cat((v_cond_k, audio_k), dim=1)
        v = torch.cat((v_cond_v, audio_v), dim=1)

        # Run self-attention
        attn = attention(q, k, v, mode=self.attn_mode, attn_mask=attn_mask, deterministic=self.deterministic)
        v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        # Apply self-attention output to audio and v_cond (gated residual).
        audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
        v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)

        # 2. Cross Attention: (v_cond, audio) as query, text as key/value
        # audio, v_cond modulation
        audio_modulated = self.audio_norm2(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod2_shift, scale=audio_mod2_scale)
        v_cond_modulated = self.v_cond_norm2(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod2_shift, scale=v_cond_mod2_scale)

        # Prepare audio query
        audio_q = self.audio_cross_q(audio_modulated)
        audio_q = rearrange(audio_q, "B L (H D) -> B L H D", H=self.num_heads)
        audio_q = self.audio_cross_q_norm(audio_q)

        # Prepare v_cond query
        v_cond_q = self.v_cond_cross_q(v_cond_modulated)
        v_cond_q = rearrange(v_cond_q, "B L (H D) -> B L H D", H=self.num_heads)
        v_cond_q = self.v_cond_cross_q_norm(v_cond_q)

        # Prepare text key/value
        text_kv = self.text_cross_kv(cond)
        text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads)
        text_k = self.text_cross_k_norm(text_k).to(text_v)

        # Apply RoPE to (v_cond, audio) query and text key if needed.
        # apply_rotary_emb takes (q, k); passing the same tensor twice and
        # indexing [0]/[1] rotates a single tensor.
        head_dim = self.hidden_size // self.num_heads
        audio_cross_freqs_cos, audio_cross_freqs_sin = self.build_rope_for_text(audio_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
        audio_cross_freqs_cis = (audio_cross_freqs_cos.to(audio_q.device), audio_cross_freqs_sin.to(audio_q.device))
        audio_q = apply_rotary_emb(audio_q, audio_q, audio_cross_freqs_cis, head_first=False)[0]

        v_cond_cross_freqs_cos, v_cond_cross_freqs_sin = self.build_rope_for_text(v_cond_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
        v_cond_cross_freqs_cis = (v_cond_cross_freqs_cos.to(v_cond_q.device), v_cond_cross_freqs_sin.to(v_cond_q.device))
        v_cond_q = apply_rotary_emb(v_cond_q, v_cond_q, v_cond_cross_freqs_cis, head_first=False)[0]

        text_len = text_k.shape[1]

        text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim,
                                                                  rope_dim_list=self.rope_dim_list)
        text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device))
        text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1]

        # Concat v_cond and audio for cross-attention
        v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)

        # Run cross-attention
        cross_attn = attention(v_cond_audio_q, text_k, text_v, mode=self.attn_mode, deterministic=self.deterministic)
        v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        # Apply cross-attention output
        audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
        v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)

        # 3. Apply MLPs (third modulation stage, gated residual)
        audio = audio + apply_gate(
            self.audio_mlp(modulate(self.audio_norm3(audio), shift=audio_mod3_shift, scale=audio_mod3_scale)),
            gate=audio_mod3_gate,
        )

        # Apply visual MLP
        v_cond = v_cond + apply_gate(
            self.v_cond_mlp(modulate(self.v_cond_norm3(v_cond), shift=v_cond_mod3_shift, scale=v_cond_mod3_scale)),
            gate=v_cond_mod3_gate,
        )

        return audio, cond, v_cond
318
+
319
class SingleStreamBlock(nn.Module):
    """Single-stream transformer block: modulated self-attention (with RoPE)
    followed by a modulated convolutional MLP, both as gated residuals."""

    def __init__(self, hidden_size: int,
                 num_heads: int,
                 mlp_ratio: float,
                 qk_norm_type: str = "rms",
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads

        # factor=6: shift/scale/gate for the attention and MLP stages.
        self.modulation = ModulateDiT(
            hidden_size=hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        # NOTE(review): linear_qkv omits **factory_kwargs unlike the other
        # layers here — confirm whether that is intentional.
        self.linear_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, **factory_kwargs)
        # NOTE(review): hidden_size * mlp_ratio is a float when mlp_ratio is —
        # presumably ConvMLP coerces it; verify.
        self.linear2 = ConvMLP(hidden_size, hidden_size * mlp_ratio, kernel_size=3, padding=1, **factory_kwargs)
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        # NOTE(review): qk_norm_type is accepted but RMSNorm is always used.
        self.q_norm = nn.RMSNorm(hidden_size // num_heads)
        self.k_norm = nn.RMSNorm(hidden_size // num_heads)
        # Splits fused QKV into heads with Q/K/V stacked along the last axis.
        self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads)

    def forward(self, x: torch.Tensor, cond: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None):
        """Apply modulated self-attention + MLP to `x` conditioned on `cond`
        ([B, T, D]); `freqs_cis` are the RoPE (cos, sin) tables."""
        assert cond.ndim == 3, "Condition should be in shape of [B, T, D]"
        modulation = self.modulation(cond)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
        # AdaLN-style modulation of the pre-attention norm.
        x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa

        qkv = self.linear_qkv(x_norm1)
        q, k, v = self.rearrange(qkv).chunk(3, dim=-1)
        q = q.squeeze(-1)
        k = k.squeeze(-1)
        v = v.squeeze(-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        # head_first=True: tensors are [B, H, L, D] after the Rearrange above.
        q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True)

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, 'b h n d -> b n (h d)').contiguous()

        # Gated residuals for attention and MLP stages.
        x = x + apply_gate(self.linear1(out), gate=gate_msa)
        x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp)

        return x
375
+
376
+ class HunyuanVideoFoley(ModelMixin, ConfigMixin):
377
+ @register_to_config
378
+ def __init__(
379
+ self,
380
+ model_config,
381
+ dtype: Optional[torch.dtype] = None,
382
+ device: Optional[torch.device] = None,
383
+ ):
384
+ factory_kwargs = {"device": device, "dtype": dtype}
385
+ super().__init__()
386
+
387
+ model_args = model_config.model_config.model_kwargs
388
+ self.depth_triple_blocks = model_args.get("depth_triple_blocks", 19)
389
+ self.depth_single_blocks = model_args.get("depth_single_blocks", 38)
390
+ # Gradient checkpoint.
391
+ self.gradient_checkpoint = False
392
+ self.gradient_checkpoint_layers = None
393
+ if self.gradient_checkpoint:
394
+ assert self.gradient_checkpoint_layers <= self.depth_triple_blocks + self.depth_single_blocks, (
395
+ f"Gradient checkpoint layers must be less or equal than the depth of the model. "
396
+ f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and depth={self.depth_triple_blocks + self.depth_single_blocks}."
397
+ )
398
+
399
+ self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", False)
400
+
401
+ # Condition projection. Default to linear projection.
402
+ self.condition_projection = model_args.get("condition_projection", "linear")
403
+ self.condition_dim = model_args.get("condition_dim", None)
404
+ self.use_attention_mask = model_args.get("use_attention_mask", False)
405
+
406
+ self.patch_size = model_args.get("patch_size", 1)
407
+ self.visual_in_channels = model_args.get("clip_dim", 768)
408
+ self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
409
+ self.out_channels = self.audio_vae_latent_dim
410
+ self.unpatchify_channels = self.out_channels
411
+ self.reverse = model_args.get("reverse", False)
412
+
413
+ self.num_heads = model_args.get("num_heads", 24)
414
+ self.hidden_size = model_args.get("hidden_size", 3072)
415
+ self.rope_dim_list = model_args.get("rope_dim_list", None)
416
+ self.mlp_ratio = model_args.get("mlp_ratio", 4.0)
417
+ self.mlp_act_type = model_args.get("mlp_act_type", "gelu_tanh")
418
+
419
+ self.qkv_bias = model_args.get("qkv_bias", True)
420
+ self.qk_norm = model_args.get("qk_norm", True)
421
+ self.qk_norm_type = model_args.get("qk_norm_type", "rms")
422
+ self.attn_mode = model_args.get("attn_mode", "torch")
423
+
424
+ self.embedder_type = model_args.get("embedder_type", "default")
425
+
426
+ # sync condition things
427
+ self.sync_modulation = model_args.get("sync_modulation", False)
428
+ self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", False)
429
+ self.sync_feat_dim = model_args.get("sync_feat_dim", 768)
430
+ self.sync_in_ksz = model_args.get("sync_in_ksz", 1)
431
+
432
+ # condition tokens length
433
+ self.clip_len = model_args.get("clip_length", 64)
434
+ self.sync_len = model_args.get("sync_length", 192)
435
+
436
+ if self.hidden_size % self.num_heads != 0:
437
+ raise ValueError(f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}")
438
+
439
+ # Build audio patchify layer and visual gated linear projection
440
+ self.patch_size = 1
441
+ self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, **factory_kwargs)
442
+ self.visual_proj = SwiGLU(self.visual_in_channels, hidden_dim=self.hidden_size, out_dim=self.hidden_size)
443
+
444
+ # condition
445
+ if self.condition_projection == "linear":
446
+ self.cond_in = ConditionProjection(
447
+ self.condition_dim, self.hidden_size, get_activation_layer("silu"), **factory_kwargs
448
+ )
449
+ else:
450
+ raise NotImplementedError(f"Unsupported condition_projection: {self.condition_projection}")
451
+
452
+ # time modulation
453
+ self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
454
+
455
+ # visual sync embedder if needed
456
+ if self.sync_in_ksz == 1:
457
+ sync_in_padding = 0
458
+ elif self.sync_in_ksz == 3:
459
+ sync_in_padding = 1
460
+ else:
461
+ raise ValueError
462
+ if self.sync_modulation or self.add_sync_feat_to_audio:
463
+ self.sync_in = nn.Sequential(
464
+ nn.Linear(self.sync_feat_dim, self.hidden_size),
465
+ nn.SiLU(),
466
+ ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding),
467
+ )
468
+ self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim)))
469
+
470
+ self.triple_blocks = nn.ModuleList(
471
+ [
472
+ TwoStreamCABlock(
473
+ hidden_size=self.hidden_size,
474
+ num_heads=self.num_heads,
475
+ mlp_ratio=self.mlp_ratio,
476
+ mlp_act_type=self.mlp_act_type,
477
+ qk_norm=self.qk_norm,
478
+ qk_norm_type=self.qk_norm_type,
479
+ qkv_bias=self.qkv_bias,
480
+ attn_mode=self.attn_mode,
481
+ reverse=self.reverse,
482
+ interleaved_audio_visual_rope=self.interleaved_audio_visual_rope,
483
+ **factory_kwargs,
484
+ )
485
+ for _ in range(self.depth_triple_blocks)
486
+ ]
487
+ )
488
+
489
+
490
+ self.single_blocks = nn.ModuleList(
491
+ [
492
+ SingleStreamBlock(
493
+ hidden_size=self.hidden_size,
494
+ num_heads=self.num_heads,
495
+ mlp_ratio=self.mlp_ratio,
496
+ qk_norm_type=self.qk_norm_type,
497
+ **factory_kwargs,
498
+ )
499
+ for _ in range(self.depth_single_blocks)
500
+ ]
501
+ )
502
+
503
+ self.final_layer = FinalLayer1D(
504
+ self.hidden_size, self.patch_size, self.out_channels, get_activation_layer("silu"), **factory_kwargs
505
+ )
506
+ self.unpatchify_channels = self.out_channels
507
+
508
+ self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels), requires_grad=True)
509
+ self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim), requires_grad=True)
510
+ nn.init.constant_(self.empty_clip_feat, 0)
511
+ nn.init.constant_(self.empty_sync_feat, 0)
512
+
513
+ def get_empty_string_sequence(self, bs=None) -> torch.Tensor:
514
+ if bs is None:
515
+ return self.empty_string_feat
516
+ else:
517
+ return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)
518
+
519
+ def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
520
+ len = len if len is not None else self.clip_len
521
+ if bs is None:
522
+ return self.empty_clip_feat.expand(len, -1) # 15s
523
+ else:
524
+ return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1) # 15s
525
+
526
+ def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
527
+ len = len if len is not None else self.sync_len
528
+ if bs is None:
529
+ return self.empty_sync_feat.expand(len, -1)
530
+ else:
531
+ return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1)
532
+
533
+ def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
534
+ assert self.patch_size == 1
535
+ # ======================================== Build RoPE for audio tokens ======================================
536
+ target_ndim = 1 # n-d RoPE
537
+ rope_sizes = [audio_emb_len]
538
+ head_dim = self.hidden_size // self.num_heads
539
+ rope_dim_list = self.rope_dim_list
540
+ if rope_dim_list is None:
541
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
542
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
543
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
544
+ rope_dim_list=rope_dim_list,
545
+ start=rope_sizes,
546
+ theta=10000,
547
+ use_real=True,
548
+ theta_rescale_factor=1.0,
549
+ )
550
+
551
+ # ========================== Build RoPE for clip tokens =========================
552
+ target_ndim = 1 # n-d RoPE
553
+ rope_sizes = [visual_cond_len]
554
+ head_dim = self.hidden_size // self.num_heads
555
+ rope_dim_list = self.rope_dim_list
556
+ if rope_dim_list is None:
557
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
558
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
559
+ v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed(
560
+ rope_dim_list=rope_dim_list,
561
+ start=rope_sizes,
562
+ theta=10000,
563
+ use_real=True,
564
+ theta_rescale_factor=1.0,
565
+ freq_scaling=1.0 * audio_emb_len / visual_cond_len,
566
+ )
567
+ return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin
568
+
569
+ def build_rope_for_interleaved_audio_visual(self, total_len):
570
+ assert self.patch_size == 1
571
+ # ========================== Build RoPE for audio tokens ========================
572
+ target_ndim = 1 # n-d RoPE
573
+ rope_sizes = [total_len]
574
+ head_dim = self.hidden_size // self.num_heads
575
+ rope_dim_list = self.rope_dim_list
576
+ if rope_dim_list is None:
577
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
578
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
579
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
580
+ rope_dim_list=rope_dim_list,
581
+ start=rope_sizes,
582
+ theta=10000,
583
+ use_real=True,
584
+ theta_rescale_factor=1.0,
585
+ )
586
+ return freqs_cos, freqs_sin
587
+
588
+ def set_attn_mode(self, new_mode):
589
+ for block in self.triple_blocks:
590
+ block.set_attn_mode(new_mode)
591
+ for block in self.single_blocks:
592
+ block.set_attn_mode(new_mode)
593
+
594
+ def enable_deterministic(self):
595
+ for block in self.triple_blocks:
596
+ block.enable_deterministic()
597
+ for block in self.single_blocks:
598
+ block.enable_deterministic()
599
+
600
+ def disable_deterministic(self):
601
+ for block in self.triple_blocks:
602
+ block.disable_deterministic()
603
+ for block in self.single_blocks:
604
+ block.disable_deterministic()
605
+
606
    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        clip_feat: Optional[torch.Tensor] = None,
        cond: torch.Tensor = None,
        audio_mask: Optional[torch.Tensor] = None,
        cond_mask: torch.Tensor = None,
        sync_feat: Optional[torch.Tensor] = None,
        drop_visual: Optional[List[bool]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Denoise audio latents conditioned on text, visual (CLIP) and sync features.

        Args:
            x: audio VAE latents, shape (batch, audio_vae_latent_dim, length).
            t: diffusion timestep per sample (comment above: expected in [0, 1000)).
            clip_feat: visual clip features projected via `self.visual_proj`.
            cond: text/condition embedding fed through `self.cond_in`.
            audio_mask: unused unless provided by callers; see NOTE below.
            cond_mask: padding mask for `cond`; required when `self.use_attention_mask`.
            sync_feat: synchronization features; required when `self.sync_modulation`
                or `self.add_sync_feat_to_audio` is set.
            drop_visual: per-sample booleans selecting rows whose visual/sync
                conditions are replaced by learned "empty" placeholders (CFG drop).
            return_dict: when True, return {"x": audio}; otherwise the tensor.

        NOTE(review): when `drop_visual` is given, `clip_feat` and `sync_feat`
        are modified in place — callers must not reuse those tensors.
        """
        out = {}
        audio = x
        bs, _, ol = x.shape
        # Number of audio tokens after 1-D patchification.
        tl = ol // self.patch_size

        # Prepare learnable empty conditions for visual condition
        if drop_visual is not None:
            clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
            sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype)

        # ========================= Prepare time & visual modulation =========================
        vec = self.time_in(t)
        sync_vec = None
        if self.sync_modulation:
            # Per-token modulation: sync features are positionally embedded in
            # groups of 8 frames, projected, then resampled to the audio length.
            assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
            sync_feat = sync_feat.view(bs, int(sync_feat.shape[1] / 8), 8, self.sync_feat_dim) + self.sync_pos_emb
            sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)  # bs, num_segments * 8, channels
            sync_vec = self.sync_in(sync_feat)  # bs, num_segments * 8, c
            sync_vec = (
                F.interpolate(sync_vec.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
            )  # bs, tl, c
            sync_vec = sync_vec + vec.unsqueeze(1)
        elif self.add_sync_feat_to_audio:
            # Additive variant: same resampling, but the result is later added
            # directly to the audio tokens instead of used as modulation.
            assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
            sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
            sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)  # bs, num_segments * 8, channels
            sync_feat = self.sync_in(sync_feat)  # bs, num_segments * 8, c
            add_sync_feat_to_audio = (
                F.interpolate(sync_feat.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
            )  # bs, tl, c

        # ========================= Get text, audio and video clip embedding =========================
        cond = self.cond_in(cond)
        cond_seq_len = cond.shape[1]

        audio = self.audio_embedder(x)
        audio_seq_len = audio.shape[1]
        v_cond = self.visual_proj(clip_feat)
        v_cond_seq_len = v_cond.shape[1]

        # ========================= Compute attention mask =========================
        attn_mask = None
        if self.use_attention_mask:
            assert cond_mask is not None
            batch_size = audio.shape[0]
            seq_len = cond_seq_len + v_cond_seq_len + audio_seq_len

            # get default audio_mask and v_cond_mask
            # NOTE(review): the `audio_mask` argument is overwritten here —
            # caller-supplied audio masks are ignored in this mode.
            audio_mask = torch.ones((batch_size, audio_seq_len), dtype=torch.bool, device=audio.device)
            v_cond_mask = torch.ones((batch_size, v_cond_seq_len), dtype=torch.bool, device=audio.device)

            # batch_size x seq_len
            concat_mask = torch.cat([cond_mask, v_cond_mask, audio_mask], dim=1)
            # batch_size x 1 x seq_len x seq_len
            attn_mask_1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            # batch_size x 1 x seq_len x seq_len
            attn_mask_2 = attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
            attn_mask = (attn_mask_1 & attn_mask_2).bool()
            # avoids self-attention weight being NaN for text padding tokens
            attn_mask[:, :, :, 0] = True


        # ========================= Build rope for audio and clip tokens =========================
        if self.interleaved_audio_visual_rope:
            # One shared table spanning both streams (audio + visual interleaved).
            freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
            v_freqs_cos = v_freqs_sin = None
        else:
            freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin = self.build_rope_for_audio_visual(
                audio_seq_len, v_cond_seq_len
            )

        # ========================= Pass through DiT blocks =========================
        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
        v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None

        if self.add_sync_feat_to_audio:
            # Sync features are injected before the first triple-stream block.
            add_sync_layer = 0
            assert (
                add_sync_layer < self.depth_triple_blocks
            ), f"The layer to add mel_spectrogram feature and sync feature should in the triple_stream_blocks (n: {self.depth_triple_blocks})."
        # Triple-stream blocks
        for layer_num, block in enumerate(self.triple_blocks):
            if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
                audio = audio + add_sync_feat_to_audio
            triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
            if (
                self.training
                and self.gradient_checkpoint
                and (self.gradient_checkpoint_layers == -1 or layer_num < self.gradient_checkpoint_layers)
            ):
                audio, cond, v_cond = torch.utils.checkpoint.checkpoint(
                    ckpt_wrapper(block), *triple_block_args, use_reentrant=False
                )
            else:
                audio, cond, v_cond = block(*triple_block_args)

        x = audio
        if sync_vec is not None:
            # Extend per-token modulation to cover the text+visual prefix as well.
            vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
            vec = torch.cat((vec, sync_vec), dim=1)

        # NOTE(review): the visual-branch frequencies of this second rope build
        # are discarded; only the audio table is used for the single-stream blocks.
        freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
        if self.add_sync_feat_to_audio:
            # vec is (bs, c) here (sync_modulation and add_sync_feat_to_audio
            # are mutually exclusive branches above), so unsqueeze broadcasts
            # against the (bs, tl, c) sync features.
            vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1)
        if len(self.single_blocks) > 0:
            for layer_num, block in enumerate(self.single_blocks):
                single_block_args = [
                    x,
                    vec,
                    (freqs_cos, freqs_sin),
                ]
                if (
                    self.training
                    and self.gradient_checkpoint
                    and (
                        self.gradient_checkpoint_layers == -1
                        or layer_num + len(self.triple_blocks) < self.gradient_checkpoint_layers
                    )
                ):
                    x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False)
                else:
                    x = block(*single_block_args)

        audio = x

        # ========================= Final layer =========================
        if sync_vec is not None:
            vec = sync_vec
        audio = self.final_layer(audio, vec)  # (N, T, patch_size * out_channels)
        audio = self.unpatchify1d(audio, tl)

        if return_dict:
            out["x"] = audio
            return out
        return audio
754
+
755
+ def unpatchify1d(self, x, l):
756
+ # x: (N, L, patch_size * C)
757
+ # audio: (N, C, T), T == L * patch_size
758
+ c = self.unpatchify_channels
759
+ p = self.patch_size
760
+ assert l == x.shape[1]
761
+
762
+ x = x.reshape(shape=(x.shape[0], l, p, c))
763
+ x = torch.einsum("ntpc->nctp", x)
764
+ audio = x.reshape(shape=(x.shape[0], c, l * p))
765
+ return audio
766
+
767
    def params_count(self):
        """Count model parameters, broken down by block family.

        Returns:
            Dict with keys:
              - "triple": attention + MLP parameters of the triple-stream blocks
                (norm/modulation layers are intentionally excluded).
              - "single": the two fused linear layers of each single-stream block.
              - "total": all parameters of the module.
              - "attn+mlp": sum of "triple" and "single".

        NOTE(review): relies on the exact attribute names of TwoStreamCABlock /
        SingleStreamBlock defined elsewhere; update together with those classes.
        """
        counts = {
            "triple": sum(
                [
                    sum(p.numel() for p in block.audio_cross_q.parameters())
                    + sum(p.numel() for p in block.v_cond_cross_q.parameters())
                    + sum(p.numel() for p in block.text_cross_kv.parameters())
                    + sum(p.numel() for p in block.audio_self_attn_qkv.parameters())
                    + sum(p.numel() for p in block.v_cond_attn_qkv.parameters())
                    + sum(p.numel() for p in block.audio_mlp.parameters())
                    + sum(p.numel() for p in block.audio_self_proj.parameters())
                    + sum(p.numel() for p in block.v_cond_self_proj.parameters())
                    + sum(p.numel() for p in block.v_cond_mlp.parameters())
                    for block in self.triple_blocks
                ]
            ),
            "single": sum(
                [
                    sum(p.numel() for p in block.linear1.parameters())
                    + sum(p.numel() for p in block.linear2.parameters())
                    for block in self.single_blocks
                ]
            ),
            "total": sum(p.numel() for p in self.parameters()),
        }

        counts["attn+mlp"] = counts["triple"] + counts["single"]
        return counts
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/__init__.py ADDED
File without changes
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/activation_layers.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
def get_activation_layer(act_type):
    """Map an activation name to a zero-argument factory for that layer.

    Args:
        act_type: one of "gelu", "gelu_tanh", "relu", "silu".

    Returns:
        A callable that, when invoked with no arguments, constructs the
        corresponding ``nn.Module`` activation.

    Raises:
        ValueError: if ``act_type`` is not a supported name.
    """
    factories = {
        "gelu": lambda: nn.GELU(),
        # Approximate `tanh` requires torch >= 1.13
        "gelu_tanh": lambda: nn.GELU(approximate="tanh"),
        "relu": nn.ReLU,
        "silu": nn.SiLU,
    }
    if act_type not in factories:
        raise ValueError(f"Unknown activation type: {act_type}")
    return factories[act_type]
16
+
17
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    A gated MLP in which ``w1`` produces the (SiLU-activated) gate, ``w3`` the
    value, and ``w2`` projects the gated product to the output dimension.
    All three projections are bias-free.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        """
        Args:
            dim: input feature dimension.
            hidden_dim: hidden dimension of the gate/value projections.
            out_dim: output feature dimension.
        """
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/attn_layers.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.metadata
2
+ import math
3
+ from typing import Tuple, Union
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+
9
+ try:
10
+ from flash_attn import (
11
+ flash_attn_qkvpacked_func,
12
+ flash_attn_kvpacked_func,
13
+ flash_attn_varlen_kvpacked_func,
14
+ flash_attn_varlen_qkvpacked_func,
15
+ )
16
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
17
+ except ImportError:
18
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func = None, None, None
19
+ index_first_axis = None
20
+ from packaging import version
21
+ from transformers.utils.import_utils import _is_package_available
22
+
23
+ from .norm_layers import get_norm_layer
24
+
25
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
    """
    Reshape frequency tensor(s) so they broadcast against ``x``.

    The RoPE table covers (seq_len, head_dim); all other axes of ``x`` are
    collapsed to size 1 in the returned view.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis: either a real-valued ``(cos, sin)`` tuple or a single
            complex frequency tensor.
        x: target tensor for broadcasting compatibility.
        head_first: head dimension first (except batch dim) or not.

    Returns:
        The reshaped ``(cos, sin)`` tuple, or a single reshaped tensor when
        ``freqs_cis`` is not a tuple.

    Raises:
        AssertionError: if shapes do not line up with ``x``.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    # The RoPE table must match (seq_len, last_dim) of x; which axis holds
    # seq_len depends on the memory layout.
    if head_first:
        expected = (x.shape[-2], x.shape[-1])
        keep_axes = {ndim - 2, ndim - 1}
    else:
        expected = (x.shape[1], x.shape[-1])
        keep_axes = {1, ndim - 1}
    shape = [d if i in keep_axes else 1 for i, d in enumerate(x.shape)]

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        cos, sin = freqs_cis
        assert cos.shape == expected, f"freqs_cis shape {cos.shape} does not match x shape {x.shape}"
        return cos.view(*shape), sin.view(*shape)

    # freqs_cis: values in complex space
    assert freqs_cis.shape == expected, f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
    return freqs_cis.view(*shape)
81
+
82
+
83
def rotate_half(x):
    """Rotate interleaved (real, imag) pairs: (a, b) -> (-b, a).

    Treats the last dimension of ``x`` as interleaved complex pairs and
    multiplies each pair by i. Expects a 4-D [B, S, H, D] tensor (``flatten(3)``
    restores the last axis); computation is done in float32.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)  # [B, S, H, D//2, 2]
    real, imag = pairs.unbind(-1)
    return torch.stack((-imag, real), dim=-1).flatten(3)
86
+
87
+
88
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
        freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
            A (cos, sin) tuple selects the real-arithmetic path; a single complex
            tensor selects the complex-multiplication path.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.

    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        # Rotation in real arithmetic (computed in float32, cast back):
        # real * cos - imag * sin
        # imag * cos + real * sin
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
        # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
        # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out
131
+
132
+
133
class BasicAttentionLayer(nn.Module):
    """Base class for attention layers: holds the backend mode ("flash",
    "torch", "vanilla", ...) and a determinism flag that subclasses consult
    when dispatching the attention computation."""

    def __init__(self, attn_mode="flash", deterministic=False):
        super().__init__()
        self.attn_mode = attn_mode
        self.deterministic = deterministic

    def set_attn_mode(self, new_mode):
        """Select a different attention backend at runtime."""
        self.attn_mode = new_mode

    def enable_deterministic(self):
        """Force deterministic attention kernels."""
        self.deterministic = True

    def disable_deterministic(self):
        """Allow non-deterministic (faster) attention kernels."""
        self.deterministic = False
147
+
148
+
149
# Per-mode (pre_attn, post_attn) layout adapters applied around the attention op.
# flash-attn kernels consume [b, s, heads, dim] directly, so those entries are
# identity; torch SDPA and the vanilla matmul path expect [b, heads, s, dim],
# so seq and head axes are swapped before and after.
MEMORY_LAYOUT = {
    "self_flash": (
        lambda x: x,
        lambda x: x,
    ),
    "cross_flash": (
        lambda x: x,
        lambda x: x,
    ),
    "flash_torch_sp": (
        lambda x: x,
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}
171
+
172
+
173
+ # Copyed from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/modeling_flash_attention_utils.py#L33C1-L57C6
174
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
175
+ """
176
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
177
+
178
+ Arguments:
179
+ attention_mask (`torch.Tensor`):
180
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
181
+
182
+ Return:
183
+ indices (`torch.Tensor):
184
+ The indices of non-masked tokens from the flattened input sequence.
185
+ cu_seqlens (`torch.Tensor`):
186
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
187
+ max_seqlen_in_batch (`int`):
188
+ Maximum sequence length in batch.
189
+ """
190
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
191
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
192
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
193
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
194
+ return (
195
+ indices,
196
+ cu_seqlens,
197
+ max_seqlen_in_batch,
198
+ )
199
+
200
+
201
# Copied from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/utils/import_utils.py#L822
def is_flash_attn_greater_or_equal(library_version: str):
    """Return True iff flash_attn is installed and its version >= library_version."""
    if not _is_package_available("flash_attn"):
        return False
    installed = version.parse(importlib.metadata.version("flash_attn"))
    return installed >= version.parse(library_version)
207
+
208
+
209
def get_kv_seqlens_with_mask(attn_mask, k, v):
    """Unpad key/value tensors according to ``attn_mask`` and pack them for
    flash-attn varlen kernels.

    Args:
        attn_mask: (b, s1) validity mask for the key/value sequence.
        k, v: (b, s1, heads, dim) key and value tensors.

    Returns:
        Tuple ``(cu_seqlens_k, max_seqlen_k, kv)`` where ``kv`` is the stacked
        unpadded key/value tensor of shape (total_valid, 2, heads, dim).
    """
    indices_k, cu_seqlens_k, max_seqlen_k = _get_unpad_data(attn_mask)
    b, s1, a, d = k.shape
    # Gather only the valid tokens, then stack K and V along a new axis.
    k_flat = index_first_axis(k.reshape(b * s1, a, d), indices_k)
    v_flat = index_first_axis(v.reshape(b * s1, a, d), indices_k)
    packed_kv = torch.stack([k_flat, v_flat], dim=1)
    return cu_seqlens_k, max_seqlen_k, packed_kv
216
+
217
+
218
def get_q_seqlens(q):
    """Build cumulative sequence lengths for a fixed-length (unpadded) query batch.

    Args:
        q: (bs, s, heads, dim) query tensor; every row has the full length ``s``.

    Returns:
        Tuple ``(cu_seqlens_q, max_seqlen_q, q_flat)`` with ``q_flat`` reshaped
        to (bs * s, heads, dim) for flash-attn varlen kernels.
    """
    bs, s, a, d = q.shape
    # Uniform lengths -> cu_seqlens is just [0, s, 2s, ..., bs*s].
    cu_seqlens_q = torch.arange(0, (bs + 1) * s, step=s, dtype=torch.int32, device=q.device)
    return cu_seqlens_q, s, q.reshape(bs * s, a, d)
223
+
224
def flash_attn_no_pad(
    qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None
):
    """Run flash-attn varlen self-attention on a padded [b, s, 3, h, d] QKV tensor.

    Unpads the batch with ``key_padding_mask``, calls the packed varlen kernel,
    and repads the result back to [b, s, h, d].
    """
    # adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27
    batch_size = qkv.shape[0]
    seqlen = qkv.shape[1]
    nheads = qkv.shape[-2]
    x = rearrange(qkv, "b s three h d -> b s (three h d)")
    # unpad_input's return arity differs across flash_attn versions:
    # x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch
    # x_unpad, indices, cu_seqlens, max_s
    unpad_results = unpad_input(
        x, key_padding_mask
    )

    if len(unpad_results) == 4:
        x_unpad, indices, cu_seqlens, max_s = unpad_results
    elif len(unpad_results) == 5:
        # newer flash_attn additionally returns used_seqlens_in_batch (unused here)
        x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_results
    else:
        raise ValueError

    x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad,
        cu_seqlens,
        max_s,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    # Repad to the original (batch, seqlen) layout and split heads back out.
    output = rearrange(
        pad_input(
            rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
        ),
        "b s (h d) -> b s h d",
        h=nheads,
    )
    return output
262
+
263
+
264
def attention(
    q,
    k,
    v,
    mode,
    drop_rate=0,
    attn_mask=None,
    cond_mask=None,
    causal=False,
    deterministic=False,
    cu_seqlens=None,
    max_seqlen=None,
    cu_seqlens_k=None,
    max_seqlen_k=None,
    img_seq_len=None,
):
    """
    Perform QKV self attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
            (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        deterministic (bool): Whether to use deterministic attention. (default: False)
        cu_seqlens (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into q.
        max_seqlen (int): The maximum sequence length in the batch of q.
        cu_seqlens_k (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into kv.
        max_seqlen_k (int): The maximum sequence length in the batch of k and v.

    Returns:
        torch.Tensor: Output tensor after self attention with shape [b, s, ad]

    NOTE(review): cond_mask and img_seq_len are currently unused in this body;
    presumably kept for interface compatibility with other attention paths.
    """
    if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
        # Tuples of segments are concatenated along the sequence axis first.
        if isinstance(q, tuple):
            q = torch.cat(q, dim=1)
        if isinstance(k, tuple):
            k = torch.cat(k, dim=1)
        if isinstance(v, tuple):
            v = torch.cat(v, dim=1)
        # Adapt memory layout to the chosen backend (see MEMORY_LAYOUT).
        pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
        q = pre_attn_layout(q)
        k = pre_attn_layout(k)
        v = pre_attn_layout(v)

    if "flash" in mode:
        assert (
            flash_attn_qkvpacked_func is not None
        ), "Flash attention is not available. Please install flash_attn first."
        flash_kwargs = dict(dropout_p=drop_rate, causal=causal)
        if deterministic:
            # The `deterministic` kwarg only exists in flash_attn >= 2.4.1.
            if not is_flash_attn_greater_or_equal("2.4.1"):
                raise ValueError(
                    "Flash attention deterministic mode requires flash_attn>=2.4.1. " "Please upgrade flash_attn"
                )
            flash_kwargs["deterministic"] = deterministic

        if mode == "self_flash":
            qkv = torch.stack([q, k, v], dim=2)
            if attn_mask is not None:
                raise ValueError("Self attention does not support attention mask")
            x = flash_attn_qkvpacked_func(qkv, **flash_kwargs)

        elif mode == "cross_flash":
            kv = torch.stack([k, v], dim=2)
            if attn_mask is None:
                x = flash_attn_kvpacked_func(q, kv, **flash_kwargs)
            else:
                # Masked cross attention: unpad K/V with the mask and use the
                # varlen kernel, then restore the [b, s, a, d] shape.
                b, s, a, h = q.shape
                cu_seqlens_q, max_seqlen_q, q = get_q_seqlens(q)
                cu_seqlens_k, max_seqlen_k, kv = get_kv_seqlens_with_mask(attn_mask, k, v)

                attn_output = flash_attn_varlen_kvpacked_func(
                    q,
                    kv,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_k=max_seqlen_k,
                    **flash_kwargs,
                )
                x = attn_output.reshape(b, s, a, h)
    elif mode == 'torch':
        # SDPA accepts bool masks directly; other dtypes must match q's dtype.
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)

    elif mode == "vanilla":
        # Explicit softmax(QK^T / sqrt(d)) V reference implementation.
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        # TODO(jarvizhang): Maybe force q and k to be float32 to avoid numerical overflow
        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
        # Restore [b, s, a, d] layout and merge the head dimension.
        x = post_attn_layout(x).contiguous()
        b, s, a, d = x.shape
        out = x.reshape(b, s, -1)
    return out
390
+
391
+
392
class SelfAttentionLayer(BasicAttentionLayer):
    """Multi-head self-attention with optional QK-normalization and RoPE.

    Q, K and V are produced by a single fused linear projection; the actual
    attention kernel is selected by the base class' ``attn_mode``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=True,
        qk_norm=True,
        attn_drop=0,
        proj_drop=0,
        dtype=None,
        device=None,
        norm_type="layer",
        attn_mode="self_flash",
        deterministic=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(attn_mode, deterministic)
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.attn_drop = attn_drop

        # Flash attention only supports head dims that are multiples of 8, up to 128.
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"

        # Fused QKV projection: one matmul emits queries, keys and values.
        self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)

        norm_cls = get_norm_layer(norm_type)
        if qk_norm:
            self.q_norm = norm_cls(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            self.k_norm = norm_cls(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, freqs_cis=None, attn_mask=None):
        """
        Args:
            x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
            freqs_cis (torch.Tensor, optional): (batch, hidden_dim // 2), RoPE frequencies
            attn_mask (torch.Tensor, optional): (batch, seq_len, seq_len), mask for attention
        """
        batch, seq_len, _ = x.shape

        # Project to packed QKV and split into per-head tensors: [b, s, a, d].
        packed = self.Wqkv(x).view(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = packed.unbind(dim=2)

        # QK-Norm (identity when disabled).
        q, k = self.q_norm(q), self.k_norm(k)

        # Rotary position embedding, applied to Q and K only.
        if freqs_cis is not None:
            rot_q, rot_k = apply_rotary_emb(q, k, freqs_cis)
            assert (
                rot_q.shape == q.shape and rot_k.shape == k.shape
            ), f"qq: {rot_q.shape}, q: {q.shape}, kk: {rot_k.shape}, k: {k.shape}"
            q, k = rot_q, rot_k

        # Dropout is active only in training mode.
        context = attention(
            q,
            k,
            v,
            drop_rate=self.attn_drop if self.training else 0,
            attn_mask=attn_mask,
            mode=self.attn_mode,
            deterministic=self.deterministic,
        )
        return self.proj_drop(self.out_proj(context))
471
+
472
+
473
class CrossAttentionLayer(BasicAttentionLayer):
    """Multi-head cross-attention: queries from ``x``, keys/values from ``y``.

    Keys and values come from one fused projection of the context stream; the
    attention kernel is selected by the base class' ``attn_mode``.
    """

    def __init__(
        self,
        qdim,
        kdim,
        num_heads,
        qkv_bias=True,
        qk_norm=True,
        attn_drop=0,
        proj_drop=0,
        dtype=None,
        device=None,
        norm_type="layer",
        attn_mode="cross_flash",
        deterministic=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(attn_mode, deterministic)
        assert qdim % num_heads == 0, "qdim must be divisible by num_heads"
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        self.head_dim = qdim // num_heads
        self.attn_drop = attn_drop

        # Flash attention only supports head dims that are multiples of 8, up to 128.
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"

        self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        # Fused key/value projection from the context stream.
        self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)

        norm_cls = get_norm_layer(norm_type)
        if qk_norm:
            self.q_norm = norm_cls(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            self.k_norm = norm_cls(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y, attn_mask=None):
        """
        Args:
            x (torch.Tensor): (batch, seq_len, hidden_dim) query stream
            y (torch.Tensor): (batch, seq_len1, hidden_dim1) key/value stream
            attn_mask (torch.Tensor): (batch, seq_len1), mask for attention
        """
        batch, q_len, _ = x.shape
        kv_len = y.shape[1]

        q = self.q_proj(x).view(batch, q_len, self.num_heads, self.head_dim)
        packed_kv = self.kv_proj(y).view(batch, kv_len, 2, self.num_heads, self.head_dim)
        k, v = packed_kv.unbind(dim=2)

        # QK-Norm (identity when disabled).
        q, k = self.q_norm(q), self.k_norm(k)

        # Dropout is active only in training mode.
        context = attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            drop_rate=self.attn_drop if self.training else 0,
            mode=self.attn_mode,
            deterministic=self.deterministic,
        )
        return self.proj_drop(self.out_proj(context))
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/embed_layers.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from ...utils.helper import to_2tuple, to_1tuple
6
+
7
class PatchEmbed1D(nn.Module):
    """1D Audio to Patch Embedding

    A convolution based approach to patchifying a 1D audio w/ embedding projection.

    Based on the impl in https://github.com/google-research/vision_transformer

    Hacked together by / Copyright 2020 Ross Wightman
    """

    def __init__(
        self,
        patch_size=1,
        in_chans=768,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        patch_size = to_1tuple(patch_size)
        self.patch_size = patch_size
        self.flatten = flatten

        # Non-overlapping patches: stride equals kernel size.
        self.proj = nn.Conv1d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
        )
        # Xavier init over the flattened (out, in*k) weight view, as in ViT.
        nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
        if bias:
            nn.init.zeros_(self.proj.bias)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Project (B, C, N) audio features to patch tokens.

        Returns (B, N // patch_size, embed_dim) when ``flatten`` is True,
        otherwise (B, embed_dim, N // patch_size).
        """
        # BUGFIX: the original assertion message stated the divisibility
        # relation backwards; it is the token count that must be divisible
        # by the patch size (which is what the condition actually checks).
        assert (
            x.shape[2] % self.patch_size[0] == 0
        ), f"The token number ({x.shape[2]}) of x must be divisible by the patch_size ({self.patch_size[0]})."

        x = self.proj(x)
        if self.flatten:
            x = x.transpose(1, 2)  # BCN -> BNC
        x = self.norm(x)
        return x
53
+
54
+
55
class ConditionProjection(nn.Module):
    """
    Projects condition embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        # Attribute names are kept for state-dict compatibility.
        self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)

    def forward(self, caption):
        """Two-layer projection: linear -> activation -> linear."""
        return self.linear_2(self.act_1(self.linear_1(caption)))
74
+
75
+
76
def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    Args:
        t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
        dim (int): the dimension of the output.
        max_period (int): controls the minimum frequency of the embeddings.

    Returns:
        embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.

    .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    """
    half = dim // 2
    # Geometric frequency ladder from 1 down toward 1/max_period (exclusive).
    exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponent).to(device=t.device)
    angles = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # Pad odd target dims with a single zero column.
        zero_col = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_col], dim=-1)
    return embedding
103
+
104
+
105
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations via sinusoidal
    features followed by a two-layer MLP.
    """

    def __init__(
        self,
        hidden_size,
        act_layer,
        frequency_embedding_size=256,
        max_period=10000,
        out_size=None,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        out_size = hidden_size if out_size is None else out_size

        # Attribute name `mlp` is kept for state-dict compatibility.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )
        # Small-std init for both linear layers, as in DiT.
        nn.init.normal_(self.mlp[0].weight, std=0.02)
        nn.init.normal_(self.mlp[2].weight, std=0.02)

    def forward(self, t):
        """Map timesteps ``t`` of shape (N,) to embeddings of shape (N, out_size)."""
        target_dtype = self.mlp[0].weight.dtype
        freq_emb = timestep_embedding(t, self.frequency_embedding_size, self.max_period)
        return self.mlp(freq_emb.type(target_dtype))
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/mlp_layers.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from timm library:
2
+ # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .modulate_layers import modulate
11
+ from ...utils.helper import to_2tuple
12
+
13
class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        # A scalar bias/drop is broadcast to both layers.
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)
        # 1x1 convolutions act as position-wise linear layers when requested.
        make_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = make_layer(in_channels, hidden_channels, bias=bias_pair[0], **factory_kwargs)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        self.norm = nn.Identity() if norm_layer is None else norm_layer(hidden_channels, **factory_kwargs)
        self.fc2 = make_layer(hidden_channels, out_features, bias=bias_pair[1], **factory_kwargs)
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(self.norm(hidden)))
52
+
53
+
54
+ # copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
55
+ # only used when use_vanilla is True
56
# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
# only used when use_vanilla is True
class MLPEmbedder(nn.Module):
    """Two-layer MLP with a SiLU non-linearity (linear -> SiLU -> linear)."""

    def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Attribute names are kept for state-dict compatibility.
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)
66
+
67
+
68
class LinearWarpforSingle(nn.Module):
    """Concatenates two sequence tensors along channels and applies one linear layer."""

    def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # `in_dim` must equal the sum of both inputs' channel dims.
        self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs)

    def forward(self, x, y):
        combined = torch.cat((x, y), dim=2)
        return self.fc(combined)
77
+
78
class FinalLayer1D(nn.Module):
    """Final adaLN-modulated projection from hidden states to patch outputs.

    Both the output projection and the modulation head are zero-initialized,
    so the layer initially emits zeros (standard DiT practice).
    """

    def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # Plain LayerNorm; scale/shift come from the adaLN head instead of
        # learned affine parameters.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        # Simple shift/scale modulation head (no modulate-type distinction here).
        self.adaLN_modulation = nn.Sequential(
            act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
        )
        # Zero-initialize the modulation head.
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        """Modulate normalized ``x`` with conditioning ``c``, then project."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift=shift, scale=scale)
        return self.linear(modulated)
102
+
103
+
104
class ChannelLastConv1d(nn.Conv1d):
    """Conv1d wrapper that accepts channel-last input of shape (B, T, C)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, C) -> (B, C, T) for the convolution, then back to channel-last.
        out = super().forward(x.transpose(1, 2))
        return out.transpose(1, 2)
111
+
112
+
113
class ConvMLP(nn.Module):
    """SwiGLU-style feed-forward block built from channel-last 1D convolutions."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
        device=None,
        dtype=None,
    ):
        """
        Convolutional MLP module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.

        Attributes:
            w1: Gate branch convolution.
            w2: Output projection convolution.
            w3: Value branch convolution.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # SwiGLU convention: shrink to 2/3, then round up to a multiple.
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        conv_kwargs = dict(bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
        self.w1 = ChannelLastConv1d(dim, hidden_dim, **conv_kwargs)
        self.w2 = ChannelLastConv1d(hidden_dim, dim, **conv_kwargs)
        self.w3 = ChannelLastConv1d(dim, hidden_dim, **conv_kwargs)

    def forward(self, x):
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.w2(gated)
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/modulate_layers.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class ModulateDiT(nn.Module):
    """adaLN modulation head: activation followed by a linear expansion.

    The linear layer is zero-initialized so modulation starts as a no-op.
    """

    def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        # Zero-initialize the modulation
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = self.act(x)
        return self.linear(activated)
17
+
18
+
19
def modulate(x, shift=None, scale=None):
    """Apply adaLN-style affine modulation: ``x * (1 + scale) + shift``.

    Args:
        x (torch.Tensor): input of shape (B, D) or (B, T, D).
        shift (torch.Tensor, optional): shift of shape (B, D) (broadcast over T)
            or already broadcastable to ``x``.
        scale (torch.Tensor, optional): scale of shape (B, D) (broadcast over T)
            or already broadcastable to ``x``.

    Returns:
        torch.Tensor: modulated tensor with the same shape as ``x``.
    """
    if x.ndim == 3:
        # Broadcast per-sample (B, D) parameters over the sequence dimension.
        # BUGFIX: previously a shift/scale whose ndim was not 2 (e.g. already
        # (B, 1, D)) was replaced with None, silently dropping the modulation.
        if shift is not None and shift.ndim == 2:
            shift = shift.unsqueeze(1)
        if scale is not None and scale.ndim == 2:
            scale = scale.unsqueeze(1)
    if scale is None and shift is None:
        return x
    if shift is None:
        return x * (1 + scale)
    if scale is None:
        return x + shift
    return x * (1 + scale) + shift
31
+
32
+
33
def apply_gate(x, gate=None, tanh=False):
    """Scale ``x`` by a gating tensor, optionally squashing the gate with tanh.

    A (B, D) gate is broadcast over the sequence dimension of a (B, T, D) input.
    Returns ``x`` unchanged when no gate is given.
    """
    if gate is None:
        return x
    if gate.ndim == 2 and x.ndim == 3:
        # Broadcast per-sample gates over the sequence dimension.
        gate = gate.unsqueeze(1)
    return x * (gate.tanh() if tanh else gate)
42
+
43
+
44
def ckpt_wrapper(module):
    """Wrap ``module``'s call in a plain function (positional args only),
    e.g. for use with torch.utils.checkpoint-style APIs."""

    def ckpt_forward(*inputs):
        return module(*inputs)

    return ckpt_forward
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/norm_layers.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction).

    Normalizes over the last dimension: ``x / sqrt(mean(x^2) + eps)``,
    optionally scaled by a learnable per-channel weight.
    """

    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6,
                 device=None, dtype=None):
        """
        Args:
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool): Whether to learn a per-channel scale.
            eps (float): Small value added to the denominator for stability.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            # The parameter only exists in the affine variant; forward()
            # detects it via hasattr.
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension."""
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalize ``x`` (computed in float32, cast back to input dtype)."""
        normalized = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            normalized = normalized * self.weight
        return normalized
53
+
54
+
55
def get_norm_layer(norm_layer):
    """
    Get the normalization layer class for a norm-type name.

    Args:
        norm_layer (str): The type of normalization layer ("layer" or "rms").

    Returns:
        norm_layer (nn.Module): The normalization layer class.

    Raises:
        NotImplementedError: for unknown names.
    """
    # RMSNorm is referenced lazily so the "layer" path has no dependency on it.
    if norm_layer == "rms":
        return RMSNorm
    if norm_layer == "layer":
        return nn.LayerNorm
    raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
HunyuanVideo-Foley/hunyuanvideo_foley/models/nn/posemb_layers.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Union, Tuple
3
+
4
+
5
+ def _to_tuple(x, dim=2):
6
+ if isinstance(x, int):
7
+ return (x,) * dim
8
+ elif len(x) == dim:
9
+ return x
10
+ else:
11
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
12
+
13
+
14
def get_meshgrid_nd(start, *args, dim=2):
    """
    Get n-D meshgrid with start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
            step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
            should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
            n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (torch.Tensor): [dim, ...]
    """
    if len(args) == 0:
        # `start` is the grid size; span [0, size) per axis.
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # `start`..`args[0]` with an implicit step of 1.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        start = _to_tuple(start, dim=dim)   # Left-Top      e.g. 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom  e.g. 20,32
        num = _to_tuple(args[1], dim=dim)   # Target Size   e.g. 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch equivalent of np.linspace(start[i], stop[i], num[i], endpoint=False).
    axes = [
        torch.linspace(start[i], stop[i], num[i] + 1, dtype=torch.float32)[: num[i]]
        for i in range(dim)
    ]
    mesh = torch.meshgrid(*axes, indexing="ij")  # dim x [W, H, D]
    return torch.stack(mesh, dim=0)  # [dim, W, H, D]
57
+
58
+
59
+ #################################################################################
60
+ # Rotary Positional Embedding Functions #
61
+ #################################################################################
62
+ # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63
+
64
+
65
def get_nd_rotary_pos_embed(
    rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0
):
    """
    This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
            sum(rope_dim_list) should equal to head_dim of attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
            args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
        freq_scaling (float, optional): Frequency rescale factor, proposed in mmaudio. Defaults to 1.0.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2]
    """
    ndim = len(rope_dim_list)
    grid = get_meshgrid_nd(start, *args, dim=ndim)  # [3, W, H, D] / [2, W, H]

    # Each grid axis is encoded with its own share of the head dimension.
    per_axis = [
        get_1d_rotary_pos_embed(
            rope_dim_list[i],
            grid[i].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor,
            freq_scaling=freq_scaling,
        )
        for i in range(ndim)
    ]  # each: 2 x [WHD, rope_dim_list[i]] or [WHD, rope_dim_list[i]/2]

    if use_real:
        cos = torch.cat([emb[0] for emb in per_axis], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in per_axis], dim=1)  # (WHD, D/2)
        return cos, sin
    return torch.cat(per_axis, dim=1)  # (WHD, D/2)
110
+
111
+
112
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    freq_scaling: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the frequency tensor for complex exponential (cis) with given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return real part and imaginary part separately.
            Otherwise, return complex numbers.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
        freq_scaling (float, optional): Frequency rescale factor, proposed in mmaudio. Defaults to 1.0.

    Returns:
        freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
        freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # NTK-aware rescaling (reddit user bloc97): rescale rotary embeddings to
    # longer sequence lengths without fine-tuning.
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 1))

    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    freqs = freq_scaling * (1.0 / (theta ** exponents))  # [D/2]
    angles = torch.outer(pos, freqs)  # [S, D/2]

    if not use_real:
        # Unit-magnitude complex rotations: cos + i*sin.
        return torch.polar(torch.ones_like(angles), angles)  # complex64 [S, D/2]

    freqs_cos = angles.cos().repeat_interleave(2, dim=1)  # [S, D]
    freqs_sin = angles.sin().repeat_interleave(2, dim=1)  # [S, D]
    return freqs_cos, freqs_sin
HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .synchformer import Synchformer
HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/ast_model.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import torch
4
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
5
+
6
+ from .modeling_ast import ASTForAudioClassification, ASTConfig
7
+ from .motionformer import AveragePooling, BaseEncoderLayer, TemporalTransformerEncoderLayer
8
+ from .utils import check_if_file_exists_else_download
9
+
10
+
11
+ class AST(torch.nn.Module):
12
+ def __init__(
13
+ self,
14
+ extract_features: bool = False,
15
+ ckpt_path: str = None,
16
+ feat_type: str = None,
17
+ max_spec_t: int = None,
18
+ factorize_freq_time: bool = None,
19
+ agg_freq_module: str = None,
20
+ agg_time_module: str = None,
21
+ add_global_repr: bool = True,
22
+ agg_segments_module: str = None,
23
+ max_segments: int = None,
24
+ ) -> None:
25
+ """
26
+ extract_features: if True, then the model will return the features instead of head's output
27
+ ckpt_path: is not a path to a ckpt file, but a name of a model from the HuggingFace model hub.
28
+ feat_type: if extract_features is True, this parameter specifies the type of features to return
29
+ max_spec_t: if specified, then the model (pos emb) will be patched to support this length of spec
30
+ factorize_freq_time: if True, then the model will use a factorized freq/time aggregation
31
+ agg_freq_module: if specified, then the model will use this module for freq aggregation
32
+ agg_time_module: if specified, then the model will use this module for time aggregation
33
+ add_global_repr: if True, adds a global representation to the features (aggregation on segments)
34
+ agg_segments_module: if specified, then the model will use this module for segments aggregation
35
+ max_segments: if specified, the initialization of PE in the global agg module will use this value.
36
+ This should correspond to the max number of segments per video (if None, 16 is used)
37
+ """
38
+ super().__init__()
39
+ self.extract_features = extract_features
40
+ self.ckpt_path = ckpt_path
41
+ self.max_spec_t = max_spec_t
42
+ self.max_segments = max_segments
43
+
44
+ # depending on whether the feat extractor was pre-trained contrastively or not, we need to
45
+ # load the state dict differently.
46
+
47
+ # if ckpt is specified, then load the model from the HuggingFace model hub, otherwise init a new model
48
+ if ckpt_path == "MIT/ast-finetuned-audioset-10-10-0.4593":
49
+ revision = "c1c0c66" # fixing the revision for compatibility (V4.27.4)
50
+ self.config = ASTConfig.from_pretrained(ckpt_path, revision=revision)
51
+ full_model = ASTForAudioClassification.from_pretrained(ckpt_path, revision=revision)
52
+ logging.info(f"Loaded AST from {ckpt_path}")
53
+ else:
54
+ self.config = ASTConfig()
55
+ self.config.num_labels = 527 # 2 by default, audioset has 527 labels
56
+ full_model = ASTForAudioClassification(self.config)
57
+ logging.info("Initialized AST from scratch with the AST AudioSet config")
58
+
59
+ was_pt_on_avclip = ckpt_path is not None and ckpt_path.endswith(".pt")
60
+
61
+ # feature extractor
62
+ self.ast = full_model.audio_spectrogram_transformer
63
+
64
+ if self.extract_features:
65
+ # assign `feat_type` (use default if not specified)
66
+ self.feat_type = "last_hidden_state" if feat_type is None else feat_type
67
+ # define adapters if needed
68
+ self.factorize_freq_time = factorize_freq_time
69
+ # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
70
+ transf_enc_layer_kwargs = dict(
71
+ d_model=self.config.hidden_size,
72
+ nhead=self.config.num_attention_heads,
73
+ dim_feedforward=self.config.intermediate_size,
74
+ activation=torch.nn.GELU(),
75
+ batch_first=True,
76
+ dropout=self.config.attention_probs_dropout_prob,
77
+ layer_norm_eps=1e-6,
78
+ norm_first=True,
79
+ )
80
+ if factorize_freq_time:
81
+ self.feat_type = "last_hidden_state" # this feat_type supports factorization
82
+ # frequency aggreration
83
+ if agg_freq_module == "TransformerEncoderLayer":
84
+ self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs)
85
+ elif agg_freq_module == "AveragePooling":
86
+ self.freq_attn_agg = AveragePooling(
87
+ avg_pattern="BS D f t -> BS D t", then_permute_pattern="BS D t -> BS t D"
88
+ )
89
+ # time aggreration
90
+ if agg_time_module == "TransformerEncoderLayer":
91
+ self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
92
+ elif agg_time_module == "AveragePooling":
93
+ self.temp_attn_agg = AveragePooling(avg_pattern="BS t D -> BS D")
94
+ elif "Identity" in agg_time_module:
95
+ self.temp_attn_agg = torch.nn.Identity()
96
+ # define a global aggregation layer (aggregarate over segments)
97
+ self.add_global_repr = add_global_repr
98
+ if add_global_repr:
99
+ if agg_segments_module == "TransformerEncoderLayer":
100
+ # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
101
+ # we need to add pos emb (PE) because previously we added the same PE for each segment
102
+ pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1
103
+ self.global_attn_agg = TemporalTransformerEncoderLayer(
104
+ add_pos_emb=True,
105
+ pos_emb_drop=self.config.hidden_dropout_prob,
106
+ pos_max_len=pos_max_len,
107
+ **transf_enc_layer_kwargs,
108
+ )
109
+ elif agg_segments_module == "AveragePooling":
110
+ self.global_attn_agg = AveragePooling(avg_pattern="B S D -> B D")
111
+ else:
112
+ self.classifier = full_model.classifier
113
+
114
+ # AST.device fails with AttributeError. This is a workaround
115
+ self.device = full_model.device
116
+
117
+ # pre-trained on 12*101+2=1214 tokens, but we have less (e.g. 12*6+2=74)
118
+ self.patch_position_emb()
119
+
120
+ if was_pt_on_avclip:
121
+ # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
122
+ # and keep only the state_dict of the feat extractor
123
+ check_if_file_exists_else_download(self.ckpt_path)
124
+ ckpt = torch.load(ckpt_path, map_location="cpu")
125
+ ckpt_weights = dict()
126
+ for k, v in ckpt["state_dict"].items():
127
+ if k.startswith(("module.a_encoder.", "a_encoder.")):
128
+ k = k.replace("module.", "").replace("a_encoder.", "")
129
+ ckpt_weights[k] = v
130
+ _load_status = self.load_state_dict(ckpt_weights, strict=False)
131
+ if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
132
+ logging.warning(
133
+ f"Loading exact afeat_extractor ckpt from {self.ckpt_path} failed. \n"
134
+ f"Missing keys ({len(_load_status.missing_keys)}): "
135
+ f"{_load_status.missing_keys}, \n"
136
+ f"Unexpected keys ({len(_load_status.unexpected_keys)}): "
137
+ f"{_load_status.unexpected_keys} \n"
138
+ f"temp_attn_agg are expected to be missing if ckpt was pt contrastively."
139
+ )
140
+ else:
141
+ logging.info(f"Loading afeat_extractor ckpt from {self.ckpt_path} succeeded.")
142
+
143
+ # print the number of parameters
144
+ logging.info(f"AST: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")
145
+
146
+ def forward(
147
+ self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None, **ast_kwargs
148
+ ) -> torch.Tensor:
149
+ """
150
+ x: (B, S, T, F) where S is number of segments, F is number of (mel) frequency bins,
151
+ ast_kwargs: additional arguments for the AST model
152
+ cont_mask: (B, S, T, F) where 0s are the values to be masked out
153
+ if `for_loop=True`, we use a for loop to extract features for each segment separately.
154
+ if `for_loop=False`, we extract features for all segments at once.
155
+ Using the for loop is slower but more memory efficient, while using all segments at once
156
+ is faster but more memory inefficient.
157
+ Using for loop allows to control the memory footprint by varying the number of videos in a
158
+ batch (batch size) rather than the number of segments in a video.
159
+ """
160
+ B, S, T, F = x.shape
161
+
162
+ if for_loop:
163
+ assert cont_mask is None, "cont_mask is not supported with for_loop=True"
164
+ orig_shape_s = (B, 1, T, F)
165
+ # NOTE: since x is (B, S, T, F), and forward_segments expects (BS, T, F).
166
+ # (B, S, T, F)[:, s] is (B, T, F) or (BS, T, F) if S=1.
167
+ x = torch.cat(
168
+ [self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)], dim=1
169
+ )
170
+ else:
171
+ orig_shape = (B, S, T, F)
172
+ x = x.view(B * S, T, F)
173
+ if cont_mask is not None:
174
+ cont_mask = cont_mask.reshape(B * S, T, F)
175
+ # AST expects a tensor of shape (B*S, T, F).
176
+ x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs)
177
+ # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
178
+ x = x.view(B, S, *x.shape[1:])
179
+ # x now is of shape (B, S, D) or (B, S, t, D) if `self.temp_attn_agg` is `Identity`
180
+
181
+ global_x = None
182
+ if self.extract_features and self.add_global_repr: # lazy execution, throws AttributeError
183
+ assert len(x.shape) == 3, f"Local representation should be (B, S, D) {x.shape}"
184
+ global_x = self.global_attn_agg(x) # (B, D)
185
+
186
+ return x, global_x # x is (B, S, ...), global_x is (B, D) or None
187
+
188
    def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs):
        """x is (BS, T, F), where S is the number of segments; cont_mask is (BS, T, F): 0s to be masked out"""
        # 'pooler_output': (B, D); or 'last_hidden_state': (B, T, D) where T is [CLS, DISTILL, <tokens>]
        # x_mask is (B, T) where 0s are the values to be masked out
        x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs)

        if self.extract_features:
            # pick the configured representation (pooler / CLS / token sequence) from the AST output
            x = self.get_features_by_type(x)
            if self.factorize_freq_time:
                x = self.restore_freq_temp_dims(x, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
                if cont_mask is not None:
                    # duplicating the mask for the latent dimension (D) to be compatible with the next func
                    x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
                    x_mask = self.restore_freq_temp_dims(x_mask, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
                    # again removing the latent
                    x_mask = x_mask[:, 0, :, :]
                else:
                    x_mask = None
                x = self.freq_attn_agg(x, x_mask)  # (BS, t, D)
                # NOTE(review): `temp_attn_agg` is only defined when factorize_freq_time is True
                # (see __init__), so it is applied inside this branch — confirm against upstream.
                x = self.temp_attn_agg(x)  # (BS, D) or (BS, t, D) if self.temp_attn_agg is Identity
        else:
            # classification path: pooled representation -> class logits
            x = x["pooler_output"]
            x = self.classifier(x)
        return x
212
+
213
+ def get_features_by_type(self, x: BaseModelOutputWithPooling) -> torch.Tensor:
214
+ if self.feat_type == "pooler_output":
215
+ return x["pooler_output"] # (B, D)
216
+ elif self.feat_type == "CLS":
217
+ return x["last_hidden_state"][:, 0, :] # (B, D)
218
+ elif self.feat_type == "last_hidden_state":
219
+ return x["last_hidden_state"] # (B, 2+T, D)
220
+ elif self.feat_type == "last_hidden_state_no_AUX":
221
+ return x["last_hidden_state"][:, 2:, :] # (B, T, D) removing CLS and distill tokens
222
+ else:
223
+ raise ValueError(f"Unknown feature type: {self.feat_type}")
224
+
225
+ def restore_freq_temp_dims(self, feats, orig_shape: tuple):
226
+ """
227
+ feats are of shape (B*S, T, D)
228
+ where T = 2 + f * t (if feat_type == 'last_hidden_state')
229
+ where T = f * t (if feat_type == 'last_hidden_state_no_AUX')
230
+ Our goal is to make them of shape (B*S, f, t, D) where f and t are dimensions after patching.
231
+ From `self.ast.embeddings.patch_embeddings`, it follows that we could reshape feats:
232
+ `feats.transpose(1, 2).view(B*S, D, f, t)`
233
+
234
+ (Similar function is defined in for RGB features in `motionformer.py`)
235
+ """
236
+ B, S, T, F = orig_shape
237
+ D = self.config.hidden_size
238
+
239
+ # num patches in each dimension
240
+ f, t = self.ast.embeddings.get_shape(self.config)
241
+
242
+ if self.feat_type == "last_hidden_state":
243
+ feats = feats[:, 2:, :] # removing CLS and distill tokens
244
+
245
+ feats = feats.permute(0, 2, 1) # (B*S, D, T)
246
+ feats = feats.view(B * S, D, f, t) # (B*S, D, f, t)
247
+
248
+ return feats
249
+
250
+ def patch_position_emb(self):
251
+ if self.max_spec_t is not None:
252
+ self.config.max_length = self.max_spec_t
253
+ f, t = self.ast.embeddings.get_shape(self.config)
254
+ shortened = self.ast.embeddings.position_embeddings[:, : f * t + 2].clone() # +2 for CLS and distill tokens
255
+ self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened).to(self.device)
256
+
257
    def to(self, device):
        """AST.device fails with AttributeError. This is a workaround.

        Keeps the cached `self.device` attribute in sync with the module's
        actual placement whenever the model is moved.
        """
        self.device = torch.device(device)
        return super().to(device)
261
+
262
+
263
class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
    """Aggregates features along the frequency axis.

    Follows the same logic as spatio-temporal aggregation in the visual
    feature extractor; see `BaseEncoderLayer` in `motionformer.py`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        """x: (B*S, D, f, t); if specified x_mask (B*S, f, t), 0s are the values to be masked out"""
        batch_seg, dim, n_freq, n_time = x.shape

        # fold time into the batch axis so attention runs over frequency only
        # (.reshape because the permuted tensor is non-contiguous)
        x = x.permute(0, 3, 2, 1).reshape(batch_seg * n_time, n_freq, dim)
        if x_mask is not None:
            x_mask = x_mask.permute(0, 2, 1).reshape(batch_seg * n_time, n_freq)

        # BaseEncoderLayer.forward prepends a CLS token and returns its representation
        x = super().forward(x=x, x_mask=x_mask)  # (B*S*t, D)

        # unfold time back out of the batch axis
        return x.view(batch_seg, n_time, dim)  # (B*S, t, D)
HunyuanVideo-Foley/hunyuanvideo_foley/models/synchformer/compute_desync_score.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import os
import subprocess
from pathlib import Path

import torch
import torchaudio
import torchvision
from omegaconf import OmegaConf

import data_transforms
from .synchformer import Synchformer
from .data_transforms import make_class_grid, quantize_offset
from .utils import check_if_file_exists_else_download, which_ffmpeg
14
+
15
+
16
def prepare_inputs(batch, device):
    """Move the audio and video tensors of a collated batch onto `device`."""
    return batch["audio"].to(device), batch["video"].to(device)
21
+
22
+
23
def get_test_transforms():
    """Assemble the deterministic test-time transform pipeline.

    Video is cropped/segmented into the layout the model was trained on;
    audio is converted to a log-mel spectrogram normalized with AudioSet
    statistics. The order of the steps matters.
    """
    pipeline = [
        data_transforms.EqualifyFromRight(),
        data_transforms.RGBSpatialCrop(input_size=224, is_random=False),
        data_transforms.TemporalCropAndOffset(
            crop_len_sec=5,
            # https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
            max_off_sec=2,
            max_wiggle_sec=0.0,
            do_offset=True,
            offset_type="grid",
            # NOTE(review): "null" here is a string, not None — presumably interpreted downstream; verify
            prob_oos="null",
            grid_size=21,
            segment_size_vframes=16,
            n_segments=14,
            step_size_seg=0.5,
            vfps=25,
        ),
        data_transforms.GenerateMultipleSegments(
            segment_size_vframes=16,
            n_segments=14,
            is_start_random=False,
            step_size_seg=0.5,
        ),
        data_transforms.RGBToHalfToZeroOne(),
        # motionformer normalization
        data_transforms.RGBNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        data_transforms.AudioMelSpectrogram(
            sample_rate=16000,
            win_length=400,  # 25 ms * 16 kHz
            hop_length=160,  # 10 ms * 16 kHz
            n_fft=1024,  # 2^(ceil(log2(window_size * sampling_rate)))
            n_mels=128,  # as in AST
        ),
        data_transforms.AudioLog(),
        data_transforms.PadOrTruncate(max_spec_t=66),
        # AST, pre-trained on AudioSet
        data_transforms.AudioNormalizeAST(mean=-4.2677393, std=4.5689974),
        data_transforms.PermuteStreams(
            einops_order_audio="S F T -> S 1 F T", einops_order_rgb="S T C H W -> S T C H W"  # same
        ),
    ]
    return torchvision.transforms.Compose(pipeline)
65
+
66
+
67
def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None):
    """Load RGB frames and a mono audio track from a video file.

    Returns:
        rgb: (Tv, 3, H, W) uint8 in [0, 255]
        audio: (Ta,) — channels averaged to mono
        meta: dict with video fps and audio framerate (legacy layout)
    """
    # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta)
    rgb, audio, raw_meta = torchvision.io.read_video(str(path), start_sec, end_sec, "sec", output_format="TCHW")
    assert raw_meta["video_fps"], f"No video fps for {path}"
    # collapse channels to mono: (Ta) <- (Ca, Ta)
    audio = audio.mean(dim=0)
    # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader.
    legacy_meta = {
        "video": {"fps": [raw_meta["video_fps"]]},
        "audio": {"framerate": [raw_meta["audio_fps"]]},
    }
    return rgb, audio, legacy_meta
80
+
81
+
82
def reencode_video(path, vfps=25, afps=16000, in_size=256):
    """Re-encode `path` to fixed video fps, audio rate and minimum side length.

    Writes the re-encoded .mp4 (and a sibling 16-bit mono .wav) under ./vis/
    and returns the new .mp4 path. Requires ffmpeg on PATH.
    """
    assert which_ffmpeg() != "", "Is ffmpeg installed? Check if the conda environment is activated."
    new_path = Path.cwd() / "vis" / f"{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4"
    new_path.parent.mkdir(exist_ok=True)
    new_path = str(new_path)
    # 1) change fps, 2) resize: min(H,W)=in_size (vertical vids are supported),
    # 3) crop to even dimensions (required by common codecs), 4) change audio framerate.
    # Arguments are passed as a list (not a split shell string) so paths with spaces survive;
    # the filter expression strings are identical to the previous shell form.
    vf_expr = (
        f"fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',"
        "crop='trunc(iw/2)'*2:'trunc(ih/2)'*2"
    )
    subprocess.call([
        which_ffmpeg(),
        "-hide_banner", "-loglevel", "panic",  # no info/error printing
        "-y", "-i", str(path),
        "-vf", vf_expr,
        "-ar", str(afps),
        new_path,
    ])
    # extract a 16-bit mono wav next to the mp4
    subprocess.call([
        which_ffmpeg(),
        "-hide_banner", "-loglevel", "panic",
        "-y", "-i", new_path,
        "-acodec", "pcm_s16le", "-ac", "1",
        new_path.replace(".mp4", ".wav"),
    ])
    return new_path
103
+
104
+
105
def decode_single_video_prediction(off_logits, grid, item):
    """Print the ground-truth offset and the top-k predicted offsets; return per-class probs."""
    label = item["targets"]["offset_label"].item()
    print("Ground Truth offset (sec):", f"{label:.2f} ({quantize_offset(grid, label)[-1].item()})")
    print()
    print("Prediction Results:")
    off_probs = torch.softmax(off_logits, dim=-1)
    k = min(off_probs.shape[-1], 5)
    topk_logits, topk_preds = torch.topk(off_logits, k)
    # we only support a single video per call
    assert len(topk_logits) == 1, "batch is larger than 1"
    # squeeze the batch dimension before reporting
    off_logits, off_probs = off_logits[0], off_probs[0]
    for target_hat in topk_preds[0]:
        print(f'p={off_probs[target_hat]:.4f} ({off_logits[target_hat]:.4f}), "{grid[target_hat]:.2f}" ({target_hat})')
    return off_probs
122
+
123
+
124
def main(args):
    """Estimate audio-visual desynchronization for a single video.

    Re-encodes the input to the expected formats if needed, applies the
    test-time transforms, runs the module-level `synchformer` model, and
    prints the predicted offset distribution.
    """
    # target formats the model was trained with
    vfps = 25
    afps = 16000
    in_size = 256
    # making the offset class grid similar to the one used in transforms,
    # refer to the used one: https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
    max_off_sec = 2
    num_cls = 21

    # checking if the provided video has the correct frame rates
    print(f"Using video: {args.vid_path}")
    v, _, info = torchvision.io.read_video(args.vid_path, pts_unit="sec")
    _, H, W, _ = v.shape
    if info["video_fps"] != vfps or info["audio_fps"] != afps or min(H, W) != in_size:
        print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=" ")
        print(f'afps: {info["audio_fps"]} -> {afps};', end=" ")
        print(f"{(H, W)} -> min(H, W)={in_size}")
        args.vid_path = reencode_video(args.vid_path, vfps, afps, in_size)
    else:
        print(f'Skipping reencoding. vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}')

    device = torch.device(args.device)

    # load visual and audio streams
    # rgb: (Tv, 3, H, W) in [0, 225], audio: (Ta,) in [-1, 1]
    rgb, audio, meta = get_video_and_audio(args.vid_path, get_meta=True)

    # making an item (dict) to apply transformations
    # NOTE: here is how it works:
    # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3`
    # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio
    # track by `args.offset_sec` seconds. It means that if `offset_sec` > 0, the audio will
    # start by `offset_sec` earlier than the rgb track.
    # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`)
    item = dict(
        video=rgb,
        audio=audio,
        meta=meta,
        path=args.vid_path,
        split="test",
        targets={
            "v_start_i_sec": args.v_start_i_sec,
            "offset_sec": args.offset_sec,
        },
    )

    grid = make_class_grid(-max_off_sec, max_off_sec, num_cls)
    if not (min(grid) <= item["targets"]["offset_sec"] <= max(grid)):
        print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}')

    # applying the test-time transform
    item = get_test_transforms()(item)

    # prepare inputs for inference
    batch = torch.utils.data.default_collate([item])
    aud, vid = prepare_inputs(batch, device)

    # TODO:
    # sanity check: we will take the input to the `model` and recontruct make a video from it.
    # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified)
    # reconstruct_video_from_input(aud, vid, batch['meta'], args.vid_path, args.v_start_i_sec, args.offset_sec,
    #                              vfps, afps)

    # forward pass
    # NOTE(review): `synchformer` is a module-level global created in the __main__ block —
    # it must be defined before main() is called from elsewhere.
    with torch.set_grad_enabled(False):
        with torch.autocast("cuda", enabled=True):
            _, logits = synchformer(vid, aud)

    # simply prints the results of the prediction
    decode_single_video_prediction(logits, grid, item)
194
+
195
+
196
+ if __name__ == "__main__":
197
+ parser = argparse.ArgumentParser()
198
+ parser.add_argument("--exp_name", required=True, help="In a format: xx-xx-xxTxx-xx-xx")
199
+ parser.add_argument("--vid_path", required=True, help="A path to .mp4 video")
200
+ parser.add_argument("--offset_sec", type=float, default=0.0)
201
+ parser.add_argument("--v_start_i_sec", type=float, default=0.0)
202
+ parser.add_argument("--device", default="cuda:0")
203
+ args = parser.parse_args()
204
+
205
+ synchformer = Synchformer().cuda().eval()
206
+ synchformer.load_state_dict(
207
+ torch.load(
208
+ os.environ.get("SYNCHFORMER_WEIGHTS", f"weights/synchformer.pth"),
209
+ weights_only=True,
210
+ map_location="cpu",
211
+ )
212
+ )
213
+
214
+ main(args)