Spaces: Running on Zero

Commit cf2f35c · Parent: bc4a00b
Add application file

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full listing.
- .gitattributes +4 -0
- .gradio/certificate.pem +31 -0
- LICENSE.txt +21 -0
- accelerate_config/accelerate_config_machine_14B_multiple.yaml +19 -0
- accelerate_config/accelerate_config_machine_1B_multiple.yaml +15 -0
- app.py +718 -0
- audio_extractor.py +14 -0
- deepspeed_config/wan2.1/wan_civitai.yaml +39 -0
- deepspeed_config/zero2_offload_cpu.json +35 -0
- deepspeed_config/zero_stage2_config.json +35 -0
- deepspeed_config/zero_stage3_config.json +46 -0
- example_case/case-1/audio.wav +3 -0
- example_case/case-1/prompt.txt +1 -0
- example_case/case-1/reference.png +3 -0
- example_case/case-2/audio.wav +3 -0
- example_case/case-2/prompt.txt +1 -0
- example_case/case-2/reference.png +3 -0
- example_case/case-3/audio.wav +3 -0
- example_case/case-3/prompt.txt +1 -0
- example_case/case-3/reference.jpg +3 -0
- example_case/case-45/audio.wav +3 -0
- example_case/case-45/prompt.txt +1 -0
- example_case/case-45/reference.png +3 -0
- example_case/case-6/audio.wav +3 -0
- example_case/case-6/prompt.txt +1 -0
- example_case/case-6/reference.png +3 -0
- extract_audio_segment.py +146 -0
- lip_mask_extractor.py +70 -0
- requirements.txt +170 -0
- vocal_seperator.py +31 -0
- wan/__init__.py +3 -0
- wan/__pycache__/__init__.cpython-311.pyc +0 -0
- wan/configs/__init__.py +42 -0
- wan/configs/shared_config.py +19 -0
- wan/configs/wan_i2v_14B.py +35 -0
- wan/configs/wan_t2v_14B.py +29 -0
- wan/configs/wan_t2v_1_3B.py +29 -0
- wan/dataset/talking_video_dataset_fantasy.py +328 -0
- wan/dist/__init__.py +40 -0
- wan/dist/__pycache__/__init__.cpython-311.pyc +0 -0
- wan/dist/__pycache__/wan_xfuser.cpython-311.pyc +0 -0
- wan/dist/wan_xfuser.py +115 -0
- wan/distributed/__init__.py +0 -0
- wan/distributed/__pycache__/__init__.cpython-311.pyc +0 -0
- wan/distributed/__pycache__/fsdp.cpython-311.pyc +0 -0
- wan/distributed/fsdp.py +41 -0
- wan/distributed/xdit_context_parallel.py +192 -0
- wan/image2video.py +334 -0
- wan/models/__init__.py +0 -0
- wan/models/__pycache__/__init__.cpython-311.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
LICENSE.txt
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Shuyuan Tu.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
accelerate_config/accelerate_config_machine_14B_multiple.yaml
ADDED
@@ -0,0 +1,19 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_config_file: path/StableAvatar/deepspeed_config/zero_stage2_config.json
  deepspeed_multinode_launcher: standard
  zero3_init_flag: False
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
main_training_function: main
dynamo_backend: 'no'
num_machines: 8
num_processes: 64
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
accelerate_config/accelerate_config_machine_1B_multiple.yaml
ADDED
@@ -0,0 +1,15 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
main_training_function: main
dynamo_backend: 'no'
num_machines: 8
num_processes: 64
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
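Editor's note: both accelerate configs describe an 8-machine, 64-process run. A minimal sanity-check sketch follows (not part of the commit); the 8-GPUs-per-node figure and the train.py entry point are assumptions, only implied by num_machines and num_processes.

# sanity-check sketch for the accelerate configs above (assumes 8 GPUs per node)
import yaml  # PyYAML

def check_accelerate_config(path, gpus_per_node=8):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    expected = cfg["num_machines"] * gpus_per_node
    assert cfg["num_processes"] == expected, (
        f"num_processes={cfg['num_processes']} but "
        f"{cfg['num_machines']} machines x {gpus_per_node} GPUs = {expected}"
    )
    return cfg

# Typical launch (shown as a comment; train.py is a hypothetical entry point):
#   accelerate launch --config_file accelerate_config/accelerate_config_machine_14B_multiple.yaml train.py
cfg = check_accelerate_config("accelerate_config/accelerate_config_machine_14B_multiple.yaml")
print(cfg["distributed_type"])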
app.py
ADDED
@@ -0,0 +1,718 @@
import torch
import psutil
import argparse
import gradio as gr
import os
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import load_image
from transformers import AutoTokenizer, Wav2Vec2Model, Wav2Vec2Processor
from omegaconf import OmegaConf
from wan.models.cache_utils import get_teacache_coefficients
from wan.models.wan_fantasy_transformer3d_1B import WanTransformer3DFantasyModel
from wan.models.wan_text_encoder import WanT5EncoderModel
from wan.models.wan_vae import AutoencoderKLWan
from wan.models.wan_image_encoder import CLIPModel
from wan.pipeline.wan_inference_long_pipeline import WanI2VTalkingInferenceLongPipeline
from wan.utils.fp8_optimization import replace_parameters_by_name, convert_weight_dtype_wrapper, convert_model_weight_to_float8
from wan.utils.utils import get_image_to_video_latent, save_videos_grid
import numpy as np
import librosa
import datetime
import random
import math
import subprocess
from moviepy.editor import VideoFileClip
from huggingface_hub import snapshot_download
import shutil
try:
    from audio_separator.separator import Separator
except:
    print("Unable to use vocal separation feature. Please install audio-separator[gpu].")


if torch.cuda.is_available():
    device = "cuda"
    if torch.cuda.get_device_capability()[0] >= 8:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32


def filter_kwargs(cls, kwargs):
    import inspect
    sig = inspect.signature(cls.__init__)
    valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
    return filtered_kwargs

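# [Editor's sketch, not part of app.py] A small usage example of filter_kwargs: it drops any
# config keys the target class's __init__ does not accept, which is how the script later feeds
# scheduler_kwargs from wan_civitai.yaml (including the extra scheduler_subpath field) into
# FlowMatchEulerDiscreteScheduler. DummyScheduler below is a hypothetical stand-in class.
class DummyScheduler:
    def __init__(self, num_train_timesteps=1000, shift=5.0):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift

raw_kwargs = {"num_train_timesteps": 1000, "shift": 5.0, "scheduler_subpath": None}
clean = filter_kwargs(DummyScheduler, raw_kwargs)  # scheduler_subpath is filtered out
scheduler_demo = DummyScheduler(**clean)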
def load_transformer_model(model_version):
    """
    Load the transformer weights for the selected model version.

    Args:
        model_version (str): model version, "square" or "rec_vec"

    Returns:
        WanTransformer3DFantasyModel: the loaded transformer model
    """
    global transformer3d

    if model_version == "square":
        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")
    elif model_version == "rec_vec":
        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-rec-vec.pt")
    else:
        # Fall back to the "square" version by default
        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")

    print(f"正在加载模型: {transformer_path}")

    if os.path.exists(transformer_path):
        state_dict = torch.load(transformer_path, map_location="cpu")
        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
        m, u = transformer3d.load_state_dict(state_dict, strict=False)
        print(f"模型加载成功: {transformer_path}")
        print(f"Missing keys: {len(m)}; Unexpected keys: {len(u)}")
        return transformer3d
    else:
        print(f"错误:模型文件不存在: {transformer_path}")
        return None


REPO_ID = "FrancisRing/StableAvatar"
repo_root = snapshot_download(
    repo_id=REPO_ID,
    allow_patterns=[
        "StableAvatar-1.3B/*",
        "Wan2.1-Fun-V1.1-1.3B-InP/*",
        "wav2vec2-base-960h/*",
        "assets/**",
        "Kim_Vocal_2.onnx",
    ],
)
pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")


# Vocal-separation ONNX model
audio_separator_model_file = os.path.join(repo_root, "Kim_Vocal_2.onnx")

# model_path = "/datadrive/stableavatar/checkpoints"
# pretrained_model_name_or_path = f"{model_path}/Wan2.1-Fun-V1.1-1.3B-InP"
# pretrained_wav2vec_path = f"{model_path}/wav2vec2-base-960h"
# transformer_path = f"{model_path}/StableAvatar-1.3B/transformer3d-square.pt"
config = OmegaConf.load("deepspeed_config/wan2.1/wan_civitai.yaml")
sampler_name = "Flow"
clip_sample_n_frames = 81
tokenizer = AutoTokenizer.from_pretrained(os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), )
text_encoder = WanT5EncoderModel.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
    additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=dtype,
)
text_encoder = text_encoder.eval()
vae = AutoencoderKLWan.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')),
    additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
)
wav2vec_processor = Wav2Vec2Processor.from_pretrained(pretrained_wav2vec_path)
wav2vec = Wav2Vec2Model.from_pretrained(pretrained_wav2vec_path).to("cpu")
clip_image_encoder = CLIPModel.from_pretrained(os.path.join(pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')), )
clip_image_encoder = clip_image_encoder.eval()
transformer3d = WanTransformer3DFantasyModel.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
    low_cpu_mem_usage=False,
    torch_dtype=dtype,
)

# Load the "square" model version by default
load_transformer_model("square")
Choosen_Scheduler = scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
}[sampler_name]
scheduler = Choosen_Scheduler(
    **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
)
pipeline = WanI2VTalkingInferenceLongPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    transformer=transformer3d,
    clip_image_encoder=clip_image_encoder,
    scheduler=scheduler,
    wav2vec_processor=wav2vec_processor,
    wav2vec=wav2vec,
)


def generate(
    GPU_memory_mode,
    teacache_threshold,
    num_skip_start_steps,
    image_path,
    audio_path,
    prompt,
    negative_prompt,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    text_guide_scale,
    audio_guide_scale,
    motion_frame,
    fps,
    overlap_window_length,
    seed_param,
    overlapping_weight_scheme,
    progress=gr.Progress(track_tqdm=True),
):
    global pipeline, transformer3d
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if seed_param < 0:
        seed = random.randint(0, np.iinfo(np.int32).max)
    else:
        seed = seed_param

    if GPU_memory_mode == "sequential_cpu_offload":
        replace_parameters_by_name(transformer3d, ["modulation", ], device=device)
        transformer3d.freqs = transformer3d.freqs.to(device=device)
        pipeline.enable_sequential_cpu_offload(device=device)
    elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
        convert_model_weight_to_float8(transformer3d, exclude_module_name=["modulation", ])
        convert_weight_dtype_wrapper(transformer3d, dtype)
        pipeline.enable_model_cpu_offload(device=device)
    elif GPU_memory_mode == "model_cpu_offload":
        pipeline.enable_model_cpu_offload(device=device)
    else:
        pipeline.to(device=device)

    if teacache_threshold > 0:
        coefficients = get_teacache_coefficients(pretrained_model_name_or_path)
        pipeline.transformer.enable_teacache(
            coefficients,
            num_inference_steps,
            teacache_threshold,
            num_skip_start_steps=num_skip_start_steps,
        )

    with torch.no_grad():
        video_length = int((clip_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if clip_sample_n_frames != 1 else 1
        input_video, input_video_mask, clip_image = get_image_to_video_latent(image_path, None, video_length=video_length, sample_size=[height, width])
        sr = 16000
        vocal_input, sample_rate = librosa.load(audio_path, sr=sr)
        sample = pipeline(
            prompt,
            num_frames=video_length,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
            num_inference_steps=num_inference_steps,
            video=input_video,
            mask_video=input_video_mask,
            clip_image=clip_image,
            text_guide_scale=text_guide_scale,
            audio_guide_scale=audio_guide_scale,
            vocal_input_values=vocal_input,
            motion_frame=motion_frame,
            fps=fps,
            sr=sr,
            cond_file_path=image_path,
            overlap_window_length=overlap_window_length,
            seed=seed,
            overlapping_weight_scheme=overlapping_weight_scheme,
        ).videos
    os.makedirs("outputs", exist_ok=True)
    video_path = os.path.join("outputs", f"{timestamp}.mp4")
    save_videos_grid(sample, video_path, fps=fps)
    output_video_with_audio = os.path.join("outputs", f"{timestamp}_audio.mp4")
    subprocess.run([
        "ffmpeg", "-y", "-loglevel", "quiet", "-i", video_path, "-i", audio_path,
        "-c:v", "copy", "-c:a", "aac", "-strict", "experimental",
        output_video_with_audio
    ], check=True)

    return output_video_with_audio, seed, f"Generated outputs/{timestamp}.mp4 / 已生成outputs/{timestamp}.mp4"

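# [Editor's sketch, not part of app.py] Worked example of the frame-count rounding in
# generate(): the clip length is snapped to a multiple of the VAE's temporal compression
# ratio plus one. The ratio 4 comes from vae_kwargs in wan_civitai.yaml.
clip_sample_n_frames_example = 81
temporal_compression_ratio = 4
video_length_example = (clip_sample_n_frames_example - 1) // temporal_compression_ratio * temporal_compression_ratio + 1
print(video_length_example)  # 81 -> (80 // 4) * 4 + 1 = 81; a value like 79 would snap down to 77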
def exchange_width_height(width, height):
    return height, width, "✅ Width and Height Swapped / 宽高交换完毕"


def adjust_width_height(image):
    image = load_image(image)
    width, height = image.size
    original_area = width * height
    default_area = 512*512
    ratio = math.sqrt(original_area / default_area)
    width = width / ratio // 16 * 16
    height = height / ratio // 16 * 16
    return int(width), int(height), "✅ Adjusted Size Based on Image / 根据图片调整宽高"

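# [Editor's sketch, not part of app.py] Worked example of adjust_width_height's area
# normalisation: the image is rescaled so its area is roughly 512*512, then each side is
# floored to a multiple of 16. The 1920x1080 input size is an arbitrary example.
import math
w, h = 1920, 1080
ratio = math.sqrt((w * h) / (512 * 512))   # 2.8125
new_w = int(w / ratio // 16 * 16)          # 672
new_h = int(h / ratio // 16 * 16)          # 384
print(new_w, new_h)                        # 672 x 384 ≈ 258k pixels, close to 512*512 = 262k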
def audio_extractor(video_path):
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)  # make sure the output directory exists
    out_wav = os.path.abspath(os.path.join("outputs", f"{timestamp}.wav"))
    video = VideoFileClip(video_path)
    audio = video.audio
    audio.write_audiofile(out_wav, codec="pcm_s16le")
    return out_wav, f"Generated {out_wav} / 已生成 {out_wav}", out_wav  # third value feeds gr.File

def vocal_separation(audio_path):
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)
    # audio_separator_model_file = "checkpoints/Kim_Vocal_2.onnx"
    audio_separator = Separator(
        output_dir=os.path.abspath(os.path.join("outputs", timestamp)),
        output_single_stem="vocals",
        model_file_dir=os.path.dirname(audio_separator_model_file),
    )
    audio_separator.load_model(os.path.basename(audio_separator_model_file))
    assert audio_separator.model_instance is not None, "Fail to load audio separate model."
    outputs = audio_separator.separate(audio_path)
    vocal_audio_file = os.path.join(audio_separator.output_dir, outputs[0])
    destination_file = os.path.abspath(os.path.join("outputs", f"{timestamp}.wav"))
    shutil.copy(vocal_audio_file, destination_file)
    os.remove(vocal_audio_file)
    return destination_file, f"Generated {destination_file} / 已生成 {destination_file}", destination_file

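# [Editor's sketch, not part of app.py] Standalone use of the same audio-separator calls as
# vocal_separation() above (assumes audio-separator[gpu] is installed and Kim_Vocal_2.onnx is
# in the current directory); "input.wav" is a placeholder file name.
from audio_separator.separator import Separator

separator = Separator(output_dir="outputs/demo", output_single_stem="vocals",
                      model_file_dir=".")     # directory containing Kim_Vocal_2.onnx
separator.load_model("Kim_Vocal_2.onnx")
stems = separator.separate("input.wav")       # returns the written stem file names
print(stems[0])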
def update_language(language):
    if language == "English":
        return {
            GPU_memory_mode: gr.Dropdown(label="GPU Memory Mode", info="Normal uses 25G VRAM, model_cpu_offload uses 13G VRAM"),
            teacache_threshold: gr.Slider(label="TeaCache Threshold", info="Recommended 0.1, 0 disables TeaCache acceleration"),
            num_skip_start_steps: gr.Slider(label="Skip Start Steps", info="Recommended 5"),
            model_version: gr.Dropdown(label="Model Version", choices=["square", "rec_vec"], value="square"),
            image_path: gr.Image(label="Upload Image"),
            audio_path: gr.Audio(label="Upload Audio"),
            prompt: gr.Textbox(label="Prompt"),
            negative_prompt: gr.Textbox(label="Negative Prompt"),
            generate_button: gr.Button("🎬 Start Generation"),
            width: gr.Slider(label="Width"),
            height: gr.Slider(label="Height"),
            exchange_button: gr.Button("🔄 Swap Width/Height"),
            adjust_button: gr.Button("Adjust Size Based on Image"),
            guidance_scale: gr.Slider(label="Guidance Scale"),
            num_inference_steps: gr.Slider(label="Sampling Steps (Recommended 50)"),
            text_guide_scale: gr.Slider(label="Text Guidance Scale"),
            audio_guide_scale: gr.Slider(label="Audio Guidance Scale"),
            motion_frame: gr.Slider(label="Motion Frame"),
            fps: gr.Slider(label="FPS"),
            overlap_window_length: gr.Slider(label="Overlap Window Length"),
            seed_param: gr.Number(label="Seed (positive integer, -1 for random)"),
            overlapping_weight_scheme: gr.Dropdown(label="Overlapping Weight Scheme", choices=["uniform", "log"], value="uniform"),
            info: gr.Textbox(label="Status"),
            video_output: gr.Video(label="Generated Result"),
            seed_output: gr.Textbox(label="Seed"),
            video_path: gr.Video(label="Upload Video"),
            extractor_button: gr.Button("🎬 Start Extraction"),
            info2: gr.Textbox(label="Status"),
            audio_output: gr.Audio(label="Generated Result"),
            audio_path3: gr.Audio(label="Upload Audio"),
            separation_button: gr.Button("🎬 Start Separation"),
            info3: gr.Textbox(label="Status"),
            audio_output3: gr.Audio(label="Generated Result"),
            example_title: gr.Markdown(value="### Select the following example cases for testing:"),
            example1_label: gr.Markdown(value="**Example 1**"),
            example2_label: gr.Markdown(value="**Example 2**"),
            example3_label: gr.Markdown(value="**Example 3**"),
            example4_label: gr.Markdown(value="**Example 4**"),
            example5_label: gr.Markdown(value="**Example 5**"),
            example1_btn: gr.Button("🚀 Use Example 1", variant="secondary"),
            example2_btn: gr.Button("🚀 Use Example 2", variant="secondary"),
            example3_btn: gr.Button("🚀 Use Example 3", variant="secondary"),
            example4_btn: gr.Button("🚀 Use Example 4", variant="secondary"),
            example5_btn: gr.Button("🚀 Use Example 5", variant="secondary"),
            parameter_settings_title: gr.Accordion(label="Parameter Settings", open=True),
            example_cases_title: gr.Accordion(label="Example Cases", open=True),
            stableavatar_title: gr.TabItem(label="StableAvatar"),
            audio_extraction_title: gr.TabItem(label="Audio Extraction"),
            vocal_separation_title: gr.TabItem(label="Vocal Separation")
        }
    else:
        return {
            GPU_memory_mode: gr.Dropdown(label="显存模式", info="Normal占用25G显存,model_cpu_offload占用13G显存"),
            teacache_threshold: gr.Slider(label="teacache threshold", info="推荐参数0.1,0为禁用teacache加速"),
            num_skip_start_steps: gr.Slider(label="跳过开始步数", info="推荐参数5"),
            model_version: gr.Dropdown(label="模型版本", choices=["square", "rec_vec"], value="square"),
            image_path: gr.Image(label="上传图片"),
            audio_path: gr.Audio(label="上传音频"),
            prompt: gr.Textbox(label="提示词"),
            negative_prompt: gr.Textbox(label="负面提示词"),
            generate_button: gr.Button("🎬 开始生成"),
            width: gr.Slider(label="宽度"),
            height: gr.Slider(label="高度"),
            exchange_button: gr.Button("🔄 交换宽高"),
            adjust_button: gr.Button("根据图片调整宽高"),
            guidance_scale: gr.Slider(label="guidance scale"),
            num_inference_steps: gr.Slider(label="采样步数(推荐50步)", minimum=1, maximum=100, step=1, value=50),
            text_guide_scale: gr.Slider(label="text guidance scale"),
            audio_guide_scale: gr.Slider(label="audio guidance scale"),
            motion_frame: gr.Slider(label="motion frame"),
            fps: gr.Slider(label="帧率"),
            overlap_window_length: gr.Slider(label="overlap window length"),
            seed_param: gr.Number(label="种子,请输入正整数,-1为随机"),
            overlapping_weight_scheme: gr.Dropdown(label="Overlapping Weight Scheme", choices=["uniform", "log"], value="uniform"),
            info: gr.Textbox(label="提示信息"),
            video_output: gr.Video(label="生成结果"),
            seed_output: gr.Textbox(label="种子"),
            video_path: gr.Video(label="上传视频"),
            extractor_button: gr.Button("🎬 开始提取"),
            info2: gr.Textbox(label="提示信息"),
            audio_output: gr.Audio(label="生成结果"),
            audio_path3: gr.Audio(label="上传音频"),
            separation_button: gr.Button("🎬 开始分离"),
            info3: gr.Textbox(label="提示信息"),
            audio_output3: gr.Audio(label="生成结果"),
            example_title: gr.Markdown(value="### 选择以下示例案例进行测试:"),
            example1_label: gr.Markdown(value="**示例 1**"),
            example2_label: gr.Markdown(value="**示例 2**"),
            example3_label: gr.Markdown(value="**示例 3**"),
            example4_label: gr.Markdown(value="**示例 4**"),
            example5_label: gr.Markdown(value="**示例 5**"),
            example1_btn: gr.Button("🚀 使用示例 1", variant="secondary"),
            example2_btn: gr.Button("🚀 使用示例 2", variant="secondary"),
            example3_btn: gr.Button("🚀 使用示例 3", variant="secondary"),
            example4_btn: gr.Button("🚀 使用示例 4", variant="secondary"),
            example5_btn: gr.Button("🚀 使用示例 5", variant="secondary"),
            parameter_settings_title: gr.Accordion(label="参数设置", open=True),
            example_cases_title: gr.Accordion(label="示例案例", open=True),
            stableavatar_title: gr.TabItem(label="StableAvatar"),
            audio_extraction_title: gr.TabItem(label="音频提取"),
            vocal_separation_title: gr.TabItem(label="人声分离")
        }

BANNER_HTML = """
<div class="hero">
  <div class="brand">
    <!-- If the project has a logo, add it to the repo and point this at it; otherwise delete this line -->
    <!-- <img src="https://raw.githubusercontent.com/Francis-Rings/StableAvatar/main/assets/logo.png" alt="StableAvatar Logo"> -->
    <span class="brand-text">STABLEAVATAR</span>
  </div>
  <div class="titles">
    <h1>StableAvatar</h1>
    <div class="badges">
      <a class="badge" href="https://arxiv.org/abs/2508.08248" target="_blank" rel="noopener">
        <img src="https://img.shields.io/badge/arXiv-2508.08248-b31b1b">
      </a>
      <a class="badge" href="https://francis-rings.github.io/StableAvatar/" target="_blank" rel="noopener">
        <img src="https://img.shields.io/badge/Webpage-Visit-2266ee">
      </a>
      <a class="badge" href="https://github.com/Francis-Rings/StableAvatar" target="_blank" rel="noopener">
        <img src="https://img.shields.io/badge/GitHub-Repo-181717?logo=github&logoColor=white">
      </a>
      <a class="badge" href="https://www.youtube.com/watch?v=6lhvmbzvv3Y" target="_blank" rel="noopener">
        <img src="https://img.shields.io/badge/YouTube-Demo-ff0000?logo=youtube&logoColor=white">
      </a>
    </div>
  </div>
</div>
<hr class="divider">
"""

BANNER_CSS = """
.hero{display:flex;align-items:center;gap:18px;padding:18px;border-radius:14px;
background:#111;color:#fff;margin-bottom:12px}
.brand-text{font-weight:800;letter-spacing:2px}
.brand img{height:46px}
.titles h1{font-size:28px;margin:0 0 6px 0}
.badges{display:flex;gap:10px;flex-wrap:wrap}
.badge img{height:22px}
.divider{border:0;border-top:1px solid rgba(255,255,255,0.18);margin:6px 0 18px}
"""


# with gr.Blocks(theme=gr.themes.Base()) as demo:
#     gr.Markdown("""
#     <div>
#         <h2 style="font-size: 30px;text-align: center;">StableAvatar</h2>
#     </div>
#     """)
with gr.Blocks(theme=gr.themes.Base(), css=BANNER_CSS) as demo:
    gr.HTML(BANNER_HTML)

    language_radio = gr.Radio(
        choices=["English", "中文"],
        value="English",
        label="Language / 语言"
    )

    with gr.Accordion("Model Settings / 模型设置", open=False):
        with gr.Row():
            GPU_memory_mode = gr.Dropdown(
                label = "显存模式",
                info = "Normal占用25G显存,model_cpu_offload占用13G显存",
                choices = ["Normal", "model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"],
                value = "model_cpu_offload"
            )
            teacache_threshold = gr.Slider(label="teacache threshold", info = "推荐参数0.1,0为禁用teacache加速", minimum=0, maximum=1, step=0.01, value=0)
            num_skip_start_steps = gr.Slider(label="跳过开始步数", info = "推荐参数5", minimum=0, maximum=100, step=1, value=5)
        with gr.Row():
            model_version = gr.Dropdown(
                label = "模型版本",
                choices = ["square","rec_vec"],
                value = "square"
            )

    stableavatar_title = gr.TabItem(label="StableAvatar")
    with stableavatar_title:
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    image_path = gr.Image(label="上传图片", type="filepath", height=280)
                    audio_path = gr.Audio(label="上传音频", type="filepath")
                prompt = gr.Textbox(label="提示词", value="")
                negative_prompt = gr.Textbox(label="负面提示词", value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
                generate_button = gr.Button("🎬 开始生成", variant='primary')
                parameter_settings_title = gr.Accordion(label="参数设置", open=True)
                with parameter_settings_title:
                    with gr.Row():
                        width = gr.Slider(label="宽度", minimum=256, maximum=2048, step=16, value=512)
                        height = gr.Slider(label="高度", minimum=256, maximum=2048, step=16, value=512)
                    with gr.Row():
                        exchange_button = gr.Button("🔄 交换宽高")
                        adjust_button = gr.Button("根据图片调整宽高")
                    with gr.Row():
                        guidance_scale = gr.Slider(label="guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=6.0)
                        num_inference_steps = gr.Slider(label="采样步数(推荐50步)", minimum=1, maximum=100, step=1, value=50)
                    with gr.Row():
                        text_guide_scale = gr.Slider(label="text guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=3.0)
                        audio_guide_scale = gr.Slider(label="audio guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=5.0)
                    with gr.Row():
                        motion_frame = gr.Slider(label="motion frame", minimum=1, maximum=50, step=1, value=25)
                        fps = gr.Slider(label="帧率", minimum=1, maximum=60, step=1, value=25)
                    with gr.Row():
                        overlap_window_length = gr.Slider(label="overlap window length", minimum=1, maximum=20, step=1, value=10)
                        seed_param = gr.Number(label="种子,请输入正整数,-1为随机", value=42)
                    with gr.Row():
                        overlapping_weight_scheme = gr.Dropdown(label="Overlapping Weight Scheme", choices=["uniform", "log"], value="uniform")
            with gr.Column():
                info = gr.Textbox(label="提示信息", interactive=False)
                video_output = gr.Video(label="生成结果", interactive=False)
                seed_output = gr.Textbox(label="种子")

        # The example-case section lives inside the StableAvatar tab
        example_cases_title = gr.Accordion(label="示例案例", open=True)
        with example_cases_title:
            example_title = gr.Markdown(value="### 选择以下示例案例进行测试:")
            with gr.Row():
                with gr.Column():
                    example1_label = gr.Markdown(value="**示例 1**")
                    example1_image = gr.Image(value="example_case/case-1/reference.png", label="", interactive=False, height=120, show_label=False)
                    example1_audio = gr.Audio(value="example_case/case-1/audio.wav", label="", interactive=False, show_label=False)
                    example1_btn = gr.Button("🚀 使用示例 1", variant="secondary", size="sm")

                with gr.Column():
                    example2_label = gr.Markdown(value="**示例 2**")
                    example2_image = gr.Image(value="example_case/case-2/reference.png", label="", interactive=False, height=120, show_label=False)
                    example2_audio = gr.Audio(value="example_case/case-2/audio.wav", label="", interactive=False, show_label=False)
                    example2_btn = gr.Button("🚀 使用示例 2", variant="secondary", size="sm")

                with gr.Column():
                    example3_label = gr.Markdown(value="**示例 3**")
                    example3_image = gr.Image(value="example_case/case-6/reference.png", label="", interactive=False, height=120, show_label=False)
                    example3_audio = gr.Audio(value="example_case/case-6/audio.wav", label="", interactive=False, show_label=False)
                    example3_btn = gr.Button("🚀 使用示例 3", variant="secondary", size="sm")

                with gr.Column():
                    example4_label = gr.Markdown(value="**示例 4**")
                    example4_image = gr.Image(value="example_case/case-45/reference.png", label="", interactive=False, height=120, show_label=False)
                    example4_audio = gr.Audio(value="example_case/case-45/audio.wav", label="", interactive=False, show_label=False)
                    example4_btn = gr.Button("🚀 使用示例 4", variant="secondary", size="sm")

                with gr.Column():
                    example5_label = gr.Markdown(value="**示例 5**")
                    example5_image = gr.Image(value="example_case/case-3/reference.jpg", label="", interactive=False, height=120, show_label=False)
                    example5_audio = gr.Audio(value="example_case/case-3/audio.wav", label="", interactive=False, show_label=False)
                    example5_btn = gr.Button("🚀 使用示例 5", variant="secondary", size="sm")

    audio_extraction_title = gr.TabItem(label="音频提取")
    with audio_extraction_title:
        with gr.Row():
            with gr.Column():
                video_path = gr.Video(label="上传视频", height=500)
                extractor_button = gr.Button("🎬 开始提取", variant='primary')
            with gr.Column():
                info2 = gr.Textbox(label="提示信息", interactive=False)
                audio_output = gr.Audio(label="生成结果", interactive=False)
                audio_file = gr.File(label="download audio file")

    vocal_separation_title = gr.TabItem(label="人声分离")
    with vocal_separation_title:
        with gr.Row():
            with gr.Column():
                audio_path3 = gr.Audio(label="上传音频", type="filepath")
                separation_button = gr.Button("🎬 开始分离", variant='primary')
            with gr.Column():
                info3 = gr.Textbox(label="提示信息", interactive=False)
                audio_output3 = gr.Audio(label="生成结果", interactive=False)
                audio_file3 = gr.File(label="download audio file")

    # (older layout) example-case section at the end of the page
    # example_cases_title = gr.Accordion(label="示例案例", open=True)
    # with example_cases_title:
    #     example_title = gr.Markdown(value="### 选择以下示例案例进行测试:")
    #     with gr.Row():
    #         with gr.Column():
    #             example1_label = gr.Markdown(value="**示例 1**")
    #             example1_image = gr.Image(value="example_case/case-1/reference.png", label="", interactive=False, height=120, show_label=False)
    #             example1_audio = gr.Audio(value="example_case/case-1/audio.wav", label="", interactive=False, show_label=False)
    #             example1_btn = gr.Button("🚀 使用示例 1", variant="secondary", size="sm")

    #         with gr.Column():
    #             example2_label = gr.Markdown(value="**示例 2**")
    #             example2_image = gr.Image(value="example_case/case-2/reference.png", label="", interactive=False, height=120, show_label=False)
    #             example2_audio = gr.Audio(value="example_case/case-2/audio.wav", label="", interactive=False, show_label=False)
    #             example2_btn = gr.Button("🚀 使用示例 2", variant="secondary", size="sm")

    #         with gr.Column():
    #             example3_label = gr.Markdown(value="**示例 3**")
    #             example3_image = gr.Image(value="example_case/case-6/reference.png", label="", interactive=False, height=120, show_label=False)
    #             example3_audio = gr.Audio(value="example_case/case-6/audio.wav", label="", interactive=False, show_label=False)
    #             example3_btn = gr.Button("🚀 使用示例 3", variant="secondary", size="sm")

    #         with gr.Column():
    #             example4_label = gr.Markdown(value="**示例 4**")
    #             example4_image = gr.Image(value="example_case/case-45/reference.png", label="", interactive=False, height=120, show_label=False)
    #             example4_audio = gr.Audio(value="example_case/case-45/audio.wav", label="", interactive=False, show_label=False)
    #             example4_btn = gr.Button("🚀 使用示例 4", variant="secondary", size="sm")

    #         with gr.Column():
    #             example5_label = gr.Markdown(value="**示例 5**")
    #             example5_image = gr.Image(value="example_case/case-3/reference.jpg", label="", interactive=False, height=120, show_label=False)
    #             example5_audio = gr.Audio(value="example_case/case-3/audio.wav", label="", interactive=False, show_label=False)
    #             example5_btn = gr.Button("🚀 使用示例 5", variant="secondary", size="sm")

    all_components = [GPU_memory_mode, teacache_threshold, num_skip_start_steps, model_version, image_path, audio_path, prompt, negative_prompt, generate_button, width, height, exchange_button, adjust_button, guidance_scale, num_inference_steps, text_guide_scale, audio_guide_scale, motion_frame, fps, overlap_window_length, seed_param, overlapping_weight_scheme, info, video_output, seed_output, video_path, extractor_button, info2, audio_output, audio_path3, separation_button, info3, audio_output3, example_title, example1_label, example2_label, example3_label, example4_label, example1_btn, example2_btn, example3_btn, example4_btn, example5_label, example5_btn, parameter_settings_title, example_cases_title, stableavatar_title, audio_extraction_title, vocal_separation_title]

    language_radio.change(
        fn=update_language,
        inputs=[language_radio],
        outputs=all_components
    )

    # Event handling for the model-version selector
    def on_model_version_change(model_version):
        """Reload the matching transformer weights when the model version changes."""
        result = load_transformer_model(model_version)
        if result is not None:
            return f"✅ 模型已切换到 {model_version} 版本"
        else:
            return f"❌ 模型切换失败,请检查文件是否存在"

    model_version.change(
        fn=on_model_version_change,
        inputs=[model_version],
        outputs=[info]
    )

    demo.load(fn=update_language, inputs=[language_radio], outputs=all_components)
    # Event handling for the example-case buttons
    def load_example1():
        try:
            with open("example_case/case-1/prompt.txt", "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
        except:
            prompt_text = ""
        return "example_case/case-1/reference.png", "example_case/case-1/audio.wav", prompt_text

    def load_example2():
        try:
            with open("example_case/case-2/prompt.txt", "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
        except:
            prompt_text = ""
        return "example_case/case-2/reference.png", "example_case/case-2/audio.wav", prompt_text

    def load_example3():
        try:
            with open("example_case/case-6/prompt.txt", "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
        except:
            prompt_text = ""
        return "example_case/case-6/reference.png", "example_case/case-6/audio.wav", prompt_text

    def load_example4():
        try:
            with open("example_case/case-45/prompt.txt", "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
        except:
            prompt_text = ""
        return "example_case/case-45/reference.png", "example_case/case-45/audio.wav", prompt_text

    def load_example5():
        try:
            with open("example_case/case-3/prompt.txt", "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
        except:
            prompt_text = ""
        return "example_case/case-3/reference.jpg", "example_case/case-3/audio.wav", prompt_text

    example1_btn.click(fn=load_example1, outputs=[image_path, audio_path, prompt])
    example2_btn.click(fn=load_example2, outputs=[image_path, audio_path, prompt])
    example3_btn.click(fn=load_example3, outputs=[image_path, audio_path, prompt])
    example4_btn.click(fn=load_example4, outputs=[image_path, audio_path, prompt])
    example5_btn.click(fn=load_example5, outputs=[image_path, audio_path, prompt])
    gr.on(
        triggers=[generate_button.click, prompt.submit, negative_prompt.submit],
        fn = generate,
        inputs = [
            GPU_memory_mode,
            teacache_threshold,
            num_skip_start_steps,
            image_path,
            audio_path,
            prompt,
            negative_prompt,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            text_guide_scale,
            audio_guide_scale,
            motion_frame,
            fps,
            overlap_window_length,
            seed_param,
            overlapping_weight_scheme,
        ],
        outputs = [video_output, seed_output, info]
    )
    exchange_button.click(
        fn=exchange_width_height,
        inputs=[width, height],
        outputs=[width, height, info]
    )
    adjust_button.click(
        fn=adjust_width_height,
        inputs=[image_path],
        outputs=[width, height, info]
    )
    extractor_button.click(
        fn=audio_extractor,
        inputs=[video_path],
        outputs=[audio_output, info2, audio_file]
    )
    separation_button.click(
        fn=vocal_separation,
        inputs=[audio_path3],
        outputs=[audio_output3, info3, audio_file3]
    )


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", 7860)),
        share=False,
        inbrowser=False,
    )
audio_extractor.py
ADDED
@@ -0,0 +1,14 @@
import os
from moviepy.editor import VideoFileClip
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", type=str)
    parser.add_argument("--saved_audio_path", type=str)
    args = parser.parse_args()
    video_path = args.video_path
    saved_audio_path = args.saved_audio_path
    video = VideoFileClip(video_path)
    audio = video.audio
    audio.write_audiofile(saved_audio_path, codec='pcm_s16le')
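Editor's note: a possible invocation of the script above, shown as a sketch; the file paths are placeholders. It dumps a video's audio track to 16-bit PCM WAV, mirroring what the "Audio Extraction" tab in app.py does.

import subprocess
subprocess.run(
    ["python", "audio_extractor.py",
     "--video_path", "path/to/talking_video.mp4",
     "--saved_audio_path", "outputs/talking_video.wav"],
    check=True,
)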
deepspeed_config/wan2.1/wan_civitai.yaml
ADDED
@@ -0,0 +1,39 @@
format: civitai
pipeline: Wan
transformer_additional_kwargs:
  transformer_subpath: ./
  dict_mapping:
    in_dim: in_channels
    dim: hidden_size

vae_kwargs:
  vae_subpath: Wan2.1_VAE.pth
  temporal_compression_ratio: 4
  spatial_compression_ratio: 8

text_encoder_kwargs:
  text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
  tokenizer_subpath: google/umt5-xxl
  text_length: 512
  vocab: 256384
  dim: 4096
  dim_attn: 4096
  dim_ffn: 10240
  num_heads: 64
  num_layers: 24
  num_buckets: 32
  shared_pos: False
  dropout: 0.0

scheduler_kwargs:
  scheduler_subpath: null
  num_train_timesteps: 1000
  shift: 5.0
  use_dynamic_shifting: false
  base_shift: 0.5
  max_shift: 1.15
  base_image_seq_len: 256
  max_image_seq_len: 4096

image_encoder_kwargs:
  image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": "auto"
|
| 4 |
+
},
|
| 5 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 6 |
+
"train_batch_size": "auto",
|
| 7 |
+
"gradient_accumulation_steps": "auto",
|
| 8 |
+
"zero_optimization": {
|
| 9 |
+
"stage": 2,
|
| 10 |
+
"overlap_comm": true,
|
| 11 |
+
"contiguous_gradients": true,
|
| 12 |
+
"sub_group_size": 1e9,
|
| 13 |
+
"offload_optimizer": {
|
| 14 |
+
"device": "cpu",
|
| 15 |
+
"pin_memory": true
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
|
| 19 |
+
"optimizer": {
|
| 20 |
+
"type": "AdamW",
|
| 21 |
+
"params": {
|
| 22 |
+
"lr": 5e-5,
|
| 23 |
+
"betas": [0.9, 0.95],
|
| 24 |
+
"weight_decay": 0.01
|
| 25 |
+
}
|
| 26 |
+
},
|
| 27 |
+
"scheduler": {
|
| 28 |
+
"type": "WarmupDecayLR",
|
| 29 |
+
"params": {
|
| 30 |
+
"warmup_min_lr": 1e-6,
|
| 31 |
+
"warmup_max_lr": 5e-5,
|
| 32 |
+
"total_num_steps": 10000
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
}
|
deepspeed_config/zero_stage2_config.json
ADDED
@@ -0,0 +1,35 @@
{
  "bf16": {
    "enabled": true
  },
  "train_micro_batch_size_per_gpu": 1,
  "train_batch_size": 64,
  "gradient_clipping": 1.0,
  "gradient_accumulation_steps": 1,
  "dump_state": true,
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 1e8,
    "contiguous_gradients": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 1e-4,
      "betas": [0.9, 0.999],
      "weight_decay": 3e-2
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 1e-7,
      "warmup_max_lr": 1e-4,
      "warmup_num_steps": 100
    }
  }
}
deepspeed_config/zero_stage3_config.json
ADDED
@@ -0,0 +1,46 @@
{
  "bf16": {
    "enabled": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 2e-5,
      "betas": [0.9, 0.999],
      "weight_decay": 3e-2
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 1e-7,
      "warmup_max_lr": 2e-5,
      "warmup_num_steps": 6400
    }
  },
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 1,
  "train_batch_size": 64,
  "gradient_clipping": 1.0,
  "steps_per_print": 2000,
  "wall_clock_breakdown": false,
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8,
    "sub_group_size": 1e9,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": "auto",
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    }
  }
}
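Editor's note: a quick consistency check for the fixed-batch DeepSpeed configs above, shown as a sketch rather than part of the commit. DeepSpeed requires train_batch_size = micro_batch x gradient_accumulation x world_size; the world size of 64 is taken from the accelerate configs earlier in this commit.

import json

with open("deepspeed_config/zero_stage3_config.json") as f:
    ds = json.load(f)

world_size = 64  # num_processes in accelerate_config_machine_14B_multiple.yaml
assert ds["train_batch_size"] == ds["train_micro_batch_size_per_gpu"] * ds["gradient_accumulation_steps"] * world_size
print(ds["zero_optimization"]["stage"], ds["scheduler"]["params"]["warmup_num_steps"])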
example_case/case-1/audio.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d12a8745971f1472c1ac5b3e3e5349163be7555b187ef3ad3cc4718393174458
size 17645370

example_case/case-1/prompt.txt
ADDED
@@ -0,0 +1 @@
Front-facing head-and-shoulders close-up of a middle-aged woman with short light brown hair, pearl earrings, and a blue blazer under soft studio lighting – She delivers a clear, confident speech with precise lip movements, steady gaze toward the camera, subtle eyebrow emphasis, slight nods, and occasional blinks while maintaining composed posture – Blurred civic architecture in the background resembling a government building, shallow depth of field, static camera.

example_case/case-1/reference.png
ADDED (binary image tracked by Git LFS)

example_case/case-2/audio.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb004cdfc7ba33e44c4128c555f43bbfdd88049a8937b1c5585db56efd59da15
size 2568018

example_case/case-2/prompt.txt
ADDED
@@ -0,0 +1 @@
Front-facing head-and-shoulders close-up of a middle-aged man with a shaved head, thin-rim glasses, and a striped shirt under soft warm lighting – He speaks clearly and thoughtfully with precise lip-sync, subtle eyebrow movement, slight nods, and occasional blinks while maintaining a steady posture – Indoor studio with blurred shutters and two warm pendant lights, shallow depth of field, and a static camera.

example_case/case-2/reference.png
ADDED (binary image tracked by Git LFS)

example_case/case-3/audio.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d2162d8ca2e9ff692c132683c9f197cc73c84a6bb6cd8a3ed5aeefbc4711ad87
size 168014

example_case/case-3/prompt.txt
ADDED
@@ -0,0 +1 @@
Front-facing head-and-shoulders close-up of an adult woman with wavy dark brown hair and silver hoop earrings under soft warm lighting – She sings “there once was a ship that put to sea, the name of the ship was the Billy” with precise lip-sync, steady tempo, subtle head sway, gentle eyebrow lifts, and occasional blinks while maintaining a composed posture – Indoor studio with a softly blurred background and warm bokeh, shallow depth of field, static camera.

example_case/case-3/reference.jpg
ADDED (binary image tracked by Git LFS)

example_case/case-45/audio.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd5db4a3d4a970a51729aff7001ac34ffb13c486c62c7e44759023403121db66
size 3076494

example_case/case-45/prompt.txt
ADDED
@@ -0,0 +1 @@
Front-facing medium close-up of a young woman with long silver hair, elf-like ears, a cozy oversized light blue scarf, and a white outfit under soft daylight – She sings a sweet, lighthearted melody with precise lip-sync, a gentle smile, relaxed breathing, subtle head sway, and natural blinks while maintaining a warm and calm demeanor – Cozy indoor room with soft light, bed and curtain in the background, shallow depth of field, static camera.

example_case/case-45/reference.png
ADDED (binary image tracked by Git LFS)

example_case/case-6/audio.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c4a1460a58d3a7662cb17494e99edc590bbab2254c80fbbd8d5c0ce327645c39
size 5821278

example_case/case-6/prompt.txt
ADDED
@@ -0,0 +1 @@
Front-facing medium close-up of a young woman with shoulder-length dark hair, wearing a white top and small hoop earrings, a studio microphone visible in the lower left under soft daylight – She sings smoothly with precise lip-sync, relaxed breathing, gentle head sway, subtle eyebrow emphasis, and natural blinks while maintaining a calm posture – Minimal indoor setting with a light gray wall and decorative molding, diagonal light and soft shadows, shallow depth of field, static camera.

example_case/case-6/reference.png
ADDED (binary image tracked by Git LFS)
extract_audio_segment.py
ADDED
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
"""
Audio conversion and segment extraction tool.
Converts an MP3 file to WAV format and extracts the audio segment for a given time range.
"""

import os
import subprocess
from pathlib import Path

def convert_mp3_to_wav_and_extract(input_file, start_time, end_time, output_dir=None):
    """
    Convert an MP3 file to WAV format and extract the audio segment for a given time range.

    Args:
        input_file (str): path to the input MP3 file
        start_time (float): start time in seconds
        end_time (float): end time in seconds
        output_dir (str): output directory; if None, the input file's directory is used

    Returns:
        bool: whether the operation succeeded
    """
    try:
        # Check that the input file exists
        if not os.path.exists(input_file):
            print(f"❌ Error: input file does not exist: {input_file}")
            return False

        # Set the output directory
        if output_dir is None:
            output_dir = os.path.dirname(input_file)

        # Make sure the output directory exists
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # Build the output file names
        input_name = Path(input_file).stem
        output_wav = os.path.join(output_dir, f"{input_name}.wav")
        output_segment = os.path.join(output_dir, f"{input_name}_segment_{start_time}s_to_{end_time}s.wav")

        print(f"🎵 Processing audio file: {input_file}")
        print(f"📁 Output directory: {output_dir}")

        # Step 1: convert MP3 to WAV
        print(f"\n🔄 Step 1: convert MP3 to WAV")
        convert_cmd = [
            'ffmpeg',
            '-y',                   # overwrite output files
            '-i', input_file,       # input file
            '-ar', '16000',         # 16 kHz sample rate
            '-ac', '1',             # mono
            '-c:a', 'pcm_s16le',    # 16-bit PCM encoding
            output_wav              # output file
        ]

        print(f"Running command: {' '.join(convert_cmd)}")
        result = subprocess.run(convert_cmd, capture_output=True, text=True)

        if result.returncode != 0:
            print(f"❌ MP3-to-WAV conversion failed: {result.stderr}")
            return False

        print(f"✅ MP3-to-WAV conversion succeeded: {output_wav}")

        # Step 2: extract the audio segment
        print(f"\n🔄 Step 2: extract audio segment ({start_time}s - {end_time}s)")
        duration = end_time - start_time

        extract_cmd = [
            'ffmpeg',
            '-y',                    # overwrite output files
            '-i', output_wav,        # input WAV file
            '-ss', str(start_time),  # start time
            '-t', str(duration),     # duration
            '-c', 'copy',            # copy the stream without re-encoding
            output_segment           # output segment file
        ]

        print(f"Running command: {' '.join(extract_cmd)}")
        result = subprocess.run(extract_cmd, capture_output=True, text=True)

        if result.returncode != 0:
            print(f"❌ Audio segment extraction failed: {result.stderr}")
            return False

        print(f"✅ Audio segment extraction succeeded: {output_segment}")

        # Show file information
        print(f"\n📊 File information:")
        print(f"Original MP3 file: {input_file}")
        print(f"Converted WAV file: {output_wav}")
        print(f"Extracted segment: {output_segment}")
        print(f"Segment duration: {duration:.1f}s")

        # Check output file sizes
        if os.path.exists(output_wav):
            wav_size = os.path.getsize(output_wav) / 1024  # KB
            print(f"WAV file size: {wav_size:.1f} KB")

        if os.path.exists(output_segment):
            segment_size = os.path.getsize(output_segment) / 1024  # KB
            print(f"Segment file size: {segment_size:.1f} KB")

        return True

    except Exception as e:
        print(f"❌ An error occurred during processing: {str(e)}")
        return False

def main():
    """Entry point."""
    print("🎵 Audio conversion and segment extraction tool")
    print("=" * 50)

    # Check that ffmpeg is installed
    try:
        subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
        print("✅ ffmpeg detected")
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("❌ Error: ffmpeg not found, please install ffmpeg first")
        print("Ubuntu/Debian: sudo apt install ffmpeg")
        print("CentOS/RHEL: sudo yum install ffmpeg")
        print("macOS: brew install ffmpeg")
        return

    # Set the file path and time parameters
    input_file = "/home/t2vg-a100-G4-42/v-shuyuantu/StableAvatar/example_case/case-3/ssvid.net--Wellerman-Female-Cover-LYRICS-Sea-Shanty.mp3"
    start_time = 1.9  # start time in seconds
    end_time = 7.1    # end time in seconds

    print(f"📁 Input file: {input_file}")
    print(f"⏰ Time range to extract: {start_time}s - {end_time}s")
    print(f"⏱️ Segment duration: {end_time - start_time:.1f}s")

    # Run the conversion and extraction
    success = convert_mp3_to_wav_and_extract(input_file, start_time, end_time)

    if success:
        print(f"\n🎉 All operations completed!")
        print(f"Output files saved in: {os.path.dirname(input_file)}")
    else:
        print(f"\n❌ Operation failed, please check the error messages")

if __name__ == "__main__":
    main()
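Note: a hypothetical call of the helper above; the MP3 path and the time window are placeholders.

# Hypothetical usage sketch: convert a local MP3 and cut out a 1.9s-7.1s window.
from extract_audio_segment import convert_mp3_to_wav_and_extract

ok = convert_mp3_to_wav_and_extract("some_song.mp3", start_time=1.9, end_time=7.1)
print("done" if ok else "failed")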
lip_mask_extractor.py
ADDED
@@ -0,0 +1,70 @@
import argparse
import os

import cv2
import mediapipe as mp
import numpy as np

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--folder_root", type=str)
    parser.add_argument("--start", type=int, help="Specify the value of start")
    parser.add_argument("--end", type=int, help="Specify the value of end")
    args = parser.parse_args()

    folder_root = args.folder_root
    start = args.start
    end = args.end

    mp_face_mesh = mp.solutions.face_mesh
    face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=10)

    upper_lip_idx = [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291]
    lower_lip_idx = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291]

    for idx in range(start, end):
        subfolder = str(idx).zfill(5)
        subfolder_path = os.path.join(folder_root, subfolder)
        images_folder = os.path.join(subfolder_path, "images")
        if os.path.exists(images_folder):
            face_masks_folder = os.path.join(subfolder_path, "lip_masks")
            os.makedirs(face_masks_folder, exist_ok=True)
            for root, dirs, files in os.walk(images_folder):
                for file in files:
                    if file.endswith('.png'):
                        file_name = os.path.splitext(file)[0]
                        image_name = file_name + '.png'
                        image_legal_path = os.path.join(images_folder, image_name)
                        if os.path.exists(os.path.join(face_masks_folder, file_name + '.png')):
                            existed_path = os.path.join(face_masks_folder, file_name + '.png')
                            print(f"{existed_path} already exists!")
                            continue

                        face_save_path = os.path.join(face_masks_folder, file_name + '.png')

                        image = cv2.imread(image_legal_path)
                        h, w, _ = image.shape
                        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                        results = face_mesh.process(rgb_image)
                        mask = np.zeros((h, w), dtype=np.uint8)

                        if results.multi_face_landmarks:
                            for face_landmarks in results.multi_face_landmarks:
                                upper_points = np.array([
                                    [int(face_landmarks.landmark[i].x * w), int(face_landmarks.landmark[i].y * h)]
                                    for i in upper_lip_idx
                                ], dtype=np.int32)
                                lower_points = np.array([
                                    [int(face_landmarks.landmark[i].x * w), int(face_landmarks.landmark[i].y * h)]
                                    for i in lower_lip_idx
                                ], dtype=np.int32)
                                cv2.fillPoly(mask, [upper_points], 255)
                                cv2.fillPoly(mask, [lower_points], 255)
                        else:
                            print(f"No face detected in {image_legal_path}. Saving empty mask.")
                        cv2.imwrite(face_save_path, mask)
                        print(f"Lip mask saved to {face_save_path}")
        else:
            print(f"{images_folder} does not exist")
            continue
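Note: the script above expects a dataset root whose subfolders are zero-padded clip indices (00000, 00001, ...) each containing an images/ directory, and it writes the masks to lip_masks/. A hypothetical invocation for the first 100 clips (the dataset path is a placeholder):

# Hypothetical invocation: build lip masks for clips 0-99.
import subprocess

subprocess.run([
    "python", "lip_mask_extractor.py",
    "--folder_root", "/path/to/dataset_root",
    "--start", "0",
    "--end", "100",
], check=True)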
requirements.txt
ADDED
@@ -0,0 +1,170 @@
absl-py==2.3.1
accelerate==1.10.0
aiofiles==24.1.0
aiohappyeyeballs==2.6.1
aiohttp==3.12.15
aiosignal==1.4.0
albucore==0.0.24
albumentations==2.0.8
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.10.0
attrs==25.3.0
audio-separator==0.36.1
audioread==3.0.1
av==15.0.0
beartype==0.18.5
beautifulsoup4==4.13.4
Brotli==1.1.0
certifi==2025.8.3
cffi==1.17.1
charset-normalizer==3.4.3
click==8.2.1
coloredlogs==15.0.1
cryptography==45.0.6
Cython==3.1.3
dashscope==1.24.1
datasets==4.0.0
decorator==4.4.2
decord==0.6.0
diffq==0.2.4
diffusers==0.30.1
dill==0.3.8
easydict==1.13
einops==0.8.1
fastapi==0.116.1
ffmpy==0.6.1
filelock==3.13.1
flatbuffers==25.2.10
frozenlist==1.7.0
fsspec==2024.6.1
ftfy==6.3.1
gradio==5.42.0
gradio_client==1.11.1
groovy==0.1.2
grpcio==1.74.0
h11==0.16.0
hf-xet==1.1.7
httpcore==1.0.9
httpx==0.28.1
huggingface-hub==0.34.4
humanfriendly==10.0
idna==3.10
imageio==2.37.0
imageio-ffmpeg==0.6.0
importlib_metadata==8.7.0
Jinja2==3.1.4
joblib==1.5.1
julius==0.2.7
lazy_loader==0.4
librosa==0.11.0
llvmlite==0.44.0
Markdown==3.8.2
markdown-it-py==4.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
ml_collections==1.1.0
ml_dtypes==0.5.3
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.1.1
multidict==6.6.4
multiprocess==0.70.16
networkx==3.3
ninja==1.13.0
numba==0.61.2
numpy==2.2.6
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
omegaconf==2.3.0
onnx-weekly==1.19.0.dev20250726
onnx2torch-py313==1.6.0
onnxruntime-gpu==1.22.0
opencv-python==4.11.0.86
opencv-python-headless==4.11.0.86
orjson==3.11.2
packaging==25.0
pandas==2.3.1
pillow==11.0.0
platformdirs==4.3.8
pooch==1.8.2
proglog==0.1.12
propcache==0.3.2
protobuf==6.31.1
psutil==7.0.0
pyarrow==21.0.0
pycparser==2.22
pydantic==2.11.7
pydantic_core==2.33.2
pydub==0.25.1
Pygments==2.19.2
python-dateutil==2.9.0.post0
python-dotenv==1.1.1
python-multipart==0.0.20
pytz==2025.2
PyYAML==6.0.2
regex==2025.7.34
requests==2.32.4
resampy==0.4.3
rich==14.1.0
rotary-embedding-torch==0.6.5
ruff==0.12.8
safehttpx==0.1.6
safetensors==0.6.2
samplerate==0.1.0
scikit-image==0.25.2
scikit-learn==1.7.1
scipy==1.16.1
semantic-version==2.10.0
sentencepiece==0.2.1
shellingham==1.5.4
simsimd==6.5.0
six==1.17.0
sniffio==1.3.1
soundfile==0.13.1
soupsieve==2.7
soxr==0.5.0.post1
starlette==0.47.2
stringzilla==3.12.6
sympy==1.13.1
tensorboard==2.20.0
tensorboard-data-server==0.7.2
threadpoolctl==3.6.0
tifffile==2025.6.11
timm==1.0.19
tokenizers==0.21.4
tomesd==0.1.3
tomlkit==0.13.3
torch==2.6.0+cu124
torchaudio==2.6.0+cu124
torchdiffeq==0.2.5
torchsde==0.2.6
torchvision==0.21.0+cu124
tqdm==4.67.1
trampoline==0.1.2
transformers==4.51.3
triton==3.2.0
typer==0.16.0
typing-inspection==0.4.1
typing_extensions==4.12.2
tzdata==2025.2
urllib3==2.5.0
uvicorn==0.35.0
wcwidth==0.2.13
websocket-client==1.8.0
websockets==15.0.1
Werkzeug==3.1.3
xxhash==3.5.0
yarl==1.20.1
zipp==3.23.0
vocal_seperator.py
ADDED
@@ -0,0 +1,31 @@
import argparse
import os
import shutil
import subprocess
from audio_separator.separator import Separator


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio_file_path", type=str)
    parser.add_argument("--saved_vocal_path", type=str)
    parser.add_argument("--audio_separator_model_file", type=str)
    args = parser.parse_args()
    audio_file_path = args.audio_file_path
    audio_separator_model_file = args.audio_separator_model_file
    saved_vocal_path = args.saved_vocal_path
    cache_dir = os.path.join(os.path.dirname(audio_file_path), "vocals")
    os.makedirs(cache_dir, exist_ok=True)
    audio_separator = Separator(
        output_dir=cache_dir,
        output_single_stem="vocals",
        model_file_dir=os.path.dirname(audio_separator_model_file),
    )
    audio_separator.load_model(os.path.basename(audio_separator_model_file))
    assert audio_separator.model_instance is not None, "Failed to load the audio separator model."
    outputs = audio_separator.separate(audio_file_path)
    subfolder_path = os.path.dirname(audio_file_path)
    vocal_audio_file = os.path.join(audio_separator.output_dir, outputs[0])
    destination_file = os.path.join(subfolder_path, "vocal.wav")
    shutil.copy(vocal_audio_file, destination_file)
    os.remove(vocal_audio_file)
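Note: a hypothetical invocation of the separator script above; all three paths are placeholders. The extracted stem is copied next to the input as vocal.wav.

# Hypothetical invocation (placeholder paths and model file).
import subprocess

subprocess.run([
    "python", "vocal_seperator.py",
    "--audio_file_path", "/path/to/clip/audio.wav",
    "--saved_vocal_path", "/path/to/clip/vocal.wav",
    "--audio_separator_model_file", "/path/to/separator_model.ckpt",
], check=True)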
wan/__init__.py
ADDED
@@ -0,0 +1,3 @@
# from . import configs, distributed, modules
# from .image2video import WanI2V
# from .text2video import WanT2V
wan/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (175 Bytes).
wan/configs/__init__.py
ADDED
@@ -0,0 +1,42 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from .wan_i2v_14B import i2v_14B
from .wan_t2v_1_3B import t2v_1_3B
from .wan_t2v_14B import t2v_14B

# the config of t2i_14B is the same as t2v_14B
t2i_14B = copy.deepcopy(t2v_14B)
t2i_14B.__name__ = 'Config: Wan T2I 14B'

WAN_CONFIGS = {
    't2v-14B': t2v_14B,
    't2v-1.3B': t2v_1_3B,
    'i2v-14B': i2v_14B,
    't2i-14B': t2i_14B,
}

SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
    '1024*1024': (1024, 1024),
}

MAX_AREA_CONFIGS = {
    '720*1280': 720 * 1280,
    '1280*720': 1280 * 720,
    '480*832': 480 * 832,
    '832*480': 832 * 480,
}

SUPPORTED_SIZES = {
    't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    't2v-1.3B': ('480*832', '832*480'),
    'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    't2i-14B': tuple(SIZE_CONFIGS.keys()),
}
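Note: these tables are plain dictionary lookups keyed by task name and by a size string such as '832*480'. A small sketch of how they are typically read:

# Sketch: look up a task config and a resolution entry from the tables above.
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, MAX_AREA_CONFIGS

task, size = "i2v-14B", "832*480"
assert size in SUPPORTED_SIZES[task]
cfg = WAN_CONFIGS[task]
print(cfg.dim, cfg.num_layers, SIZE_CONFIGS[size], MAX_AREA_CONFIGS[size])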
wan/configs/shared_config.py
ADDED
@@ -0,0 +1,19 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()

# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512

# transformer
wan_shared_cfg.param_dtype = torch.bfloat16

# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
# Default negative prompt (Chinese); roughly: "vivid colors, overexposed, static, blurry details,
# subtitles, style, artwork, painting, still frame, overall gray, worst quality, low quality, JPEG
# artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed,
# disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs,
# crowded background, walking backwards".
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
wan/configs/wan_i2v_14B.py
ADDED
@@ -0,0 +1,35 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan I2V 14B ------------------------#

i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
i2v_14B.update(wan_shared_cfg)

i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_14B.t5_tokenizer = 'google/umt5-xxl'

# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
i2v_14B.clip_tokenizer = 'xlm-roberta-large'

# vae
i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_14B.vae_stride = (4, 8, 8)

# transformer
i2v_14B.patch_size = (1, 2, 2)
i2v_14B.dim = 5120
i2v_14B.ffn_dim = 13824
i2v_14B.freq_dim = 256
i2v_14B.num_heads = 40
i2v_14B.num_layers = 40
i2v_14B.window_size = (-1, -1)
i2v_14B.qk_norm = True
i2v_14B.cross_attn_norm = True
i2v_14B.eps = 1e-6
wan/configs/wan_t2v_14B.py
ADDED
@@ -0,0 +1,29 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V 14B ------------------------#

t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
t2v_14B.update(wan_shared_cfg)

# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_14B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_14B.vae_stride = (4, 8, 8)

# transformer
t2v_14B.patch_size = (1, 2, 2)
t2v_14B.dim = 5120
t2v_14B.ffn_dim = 13824
t2v_14B.freq_dim = 256
t2v_14B.num_heads = 40
t2v_14B.num_layers = 40
t2v_14B.window_size = (-1, -1)
t2v_14B.qk_norm = True
t2v_14B.cross_attn_norm = True
t2v_14B.eps = 1e-6
wan/configs/wan_t2v_1_3B.py
ADDED
@@ -0,0 +1,29 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V 1.3B ------------------------#

t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
t2v_1_3B.update(wan_shared_cfg)

# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_1_3B.vae_stride = (4, 8, 8)

# transformer
t2v_1_3B.patch_size = (1, 2, 2)
t2v_1_3B.dim = 1536
t2v_1_3B.ffn_dim = 8960
t2v_1_3B.freq_dim = 256
t2v_1_3B.num_heads = 12
t2v_1_3B.num_layers = 30
t2v_1_3B.window_size = (-1, -1)
t2v_1_3B.qk_norm = True
t2v_1_3B.cross_attn_norm = True
t2v_1_3B.eps = 1e-6
wan/dataset/talking_video_dataset_fantasy.py
ADDED
@@ -0,0 +1,328 @@
import math
import os
import random
import warnings
import librosa
import numpy as np
import torch
from PIL import Image
import cv2
from einops import rearrange
import torchvision.transforms.functional as TF
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F


def get_random_mask(shape, image_start_only=False):
    f, c, h, w = shape
    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)

    if not image_start_only:
        if f != 1:
            mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
        else:
            mask_index = np.random.choice([0, 1], p=[0.2, 0.8])
        if mask_index == 0:
            center_x = torch.randint(0, w, (1,)).item()
            center_y = torch.randint(0, h, (1,)).item()
            block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width range of the masked block
            block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height range of the masked block

            start_x = max(center_x - block_size_x // 2, 0)
            end_x = min(center_x + block_size_x // 2, w)
            start_y = max(center_y - block_size_y // 2, 0)
            end_y = min(center_y + block_size_y // 2, h)
            mask[:, :, start_y:end_y, start_x:end_x] = 1
        elif mask_index == 1:
            mask[:, :, :, :] = 1
        elif mask_index == 2:
            mask_frame_index = np.random.randint(1, 5)
            mask[mask_frame_index:, :, :, :] = 1
        elif mask_index == 3:
            mask_frame_index = np.random.randint(1, 5)
            mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
        elif mask_index == 4:
            center_x = torch.randint(0, w, (1,)).item()
            center_y = torch.randint(0, h, (1,)).item()
            block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width range of the masked block
            block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height range of the masked block

            start_x = max(center_x - block_size_x // 2, 0)
            end_x = min(center_x + block_size_x // 2, w)
            start_y = max(center_y - block_size_y // 2, 0)
            end_y = min(center_y + block_size_y // 2, h)

            mask_frame_before = np.random.randint(0, f // 2)
            mask_frame_after = np.random.randint(f // 2, f)
            mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
        elif mask_index == 5:
            mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
        elif mask_index == 6:
            num_frames_to_mask = random.randint(1, max(f // 2, 1))
            frames_to_mask = random.sample(range(f), num_frames_to_mask)

            for i in frames_to_mask:
                block_height = random.randint(1, h // 4)
                block_width = random.randint(1, w // 4)
                top_left_y = random.randint(0, h - block_height)
                top_left_x = random.randint(0, w - block_width)
                mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
        elif mask_index == 7:
            center_x = torch.randint(0, w, (1,)).item()
            center_y = torch.randint(0, h, (1,)).item()
            a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item()  # semi-major axis
            b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()  # semi-minor axis

            for i in range(h):
                for j in range(w):
                    if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
                        mask[:, :, i, j] = 1
        elif mask_index == 8:
            center_x = torch.randint(0, w, (1,)).item()
            center_y = torch.randint(0, h, (1,)).item()
            radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
            for i in range(h):
                for j in range(w):
                    if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
                        mask[:, :, i, j] = 1
        elif mask_index == 9:
            for idx in range(f):
                if np.random.rand() > 0.5:
                    mask[idx, :, :, :] = 1
        else:
            raise ValueError(f"The mask_index {mask_index} is not defined")
    else:
        if f != 1:
            mask[1:, :, :, :] = 1
        else:
            mask[:, :, :, :] = 1
    return mask


class LargeScaleTalkingFantasyVideos(Dataset):
    def __init__(self, txt_path, width, height, n_sample_frames, sample_frame_rate, only_last_features=False, vocal_encoder=None, audio_encoder=None, vocal_sample_rate=16000, audio_sample_rate=24000, enable_inpaint=True, audio_margin=2, vae_stride=None, patch_size=None, wav2vec_processor=None, wav2vec=None):
        self.txt_path = txt_path
        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.sample_frame_rate = sample_frame_rate
        self.only_last_features = only_last_features
        self.vocal_encoder = vocal_encoder
        self.audio_encoder = audio_encoder
        self.vocal_sample_rate = vocal_sample_rate
        self.audio_sample_rate = audio_sample_rate
        self.enable_inpaint = enable_inpaint
        self.wav2vec_processor = wav2vec_processor
        self.audio_margin = audio_margin
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.max_area = height * width
        self.aspect_ratio = height / width
        self.video_files = self._read_txt_file_images()

        self.lat_h = round(
            np.sqrt(self.max_area * self.aspect_ratio) // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
        self.lat_w = round(
            np.sqrt(self.max_area / self.aspect_ratio) // self.vae_stride[2] //
            self.patch_size[2] * self.patch_size[2])

    def _read_txt_file_images(self):
        with open(self.txt_path, 'r') as file:
            lines = file.readlines()
            video_files = []
            for line in lines:
                video_file = line.strip()
                video_files.append(video_file)
        return video_files

    def __len__(self):
        return len(self.video_files)

    def frame_count(self, frames_path):
        files = os.listdir(frames_path)
        png_files = [file for file in files if file.endswith('.png') or file.endswith('.jpg')]
        png_files_count = len(png_files)
        return png_files_count

    def find_frames_list(self, frames_path):
        files = os.listdir(frames_path)
        image_files = [file for file in files if file.endswith('.png') or file.endswith('.jpg')]
        if image_files[0].startswith('frame_'):
            image_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
        else:
            image_files.sort(key=lambda x: int(x.split('.')[0]))
        return image_files

    def __getitem__(self, idx):
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=FutureWarning)

        video_path = os.path.join(self.video_files[idx], "sub_clip.mp4")
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        try:
            is_0_fps = 2 / fps
        except Exception as e:
            print(f"The fps of {video_path} is 0 !!!")
        vocal_audio_path = os.path.join(self.video_files[idx], "audio.wav")
        vocal_duration = librosa.get_duration(filename=vocal_audio_path)
        frames_path = os.path.join(self.video_files[idx], "images")
        total_frame_number = self.frame_count(frames_path)
        fps = total_frame_number / vocal_duration
        print(f"The calculated fps of {video_path} is {fps} !!!")
        # idx = random.randint(0, len(self.video_files) - 1)
        # video_path = os.path.join(self.video_files[idx], "sub_clip.mp4")
        # cap = cv2.VideoCapture(video_path)
        # fps = cap.get(cv2.CAP_PROP_FPS)

        frames_path = os.path.join(self.video_files[idx], "images")

        face_masks_path = os.path.join(self.video_files[idx], "face_masks")
        lip_masks_path = os.path.join(self.video_files[idx], "lip_masks")
        raw_audio_path = os.path.join(self.video_files[idx], "audio.wav")
        # vocal_audio_path = os.path.join(self.video_files[idx], "vocal.wav")
        vocal_audio_path = os.path.join(self.video_files[idx], "audio.wav")
        video_length = self.frame_count(frames_path)
        frames_list = self.find_frames_list(frames_path)

        clip_length = min(video_length, (self.n_sample_frames - 1) * self.sample_frame_rate + 1)

        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(
            start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
        ).tolist()
        all_indices = list(range(0, video_length))
        reference_frame_idx = random.choice(all_indices)

        tgt_pil_image_list = []
        tgt_face_masks_list = []
        tgt_lip_masks_list = []

        # reference_frame_path = os.path.join(frames_path, frames_list[reference_frame_idx])
        reference_frame_path = os.path.join(frames_path, frames_list[start_idx])
        reference_pil_image = Image.open(reference_frame_path).convert('RGB')
        reference_pil_image = reference_pil_image.resize((self.width, self.height))
        reference_pil_image = torch.from_numpy(np.array(reference_pil_image)).float()
        reference_pil_image = reference_pil_image / 127.5 - 1

        for index in batch_index:
            tgt_img_path = os.path.join(frames_path, frames_list[index])
            # file_name = os.path.splitext(os.path.basename(tgt_img_path))[0]
            file_name = os.path.basename(tgt_img_path)
            face_mask_path = os.path.join(face_masks_path, file_name)
            lip_mask_path = os.path.join(lip_masks_path, file_name)
            try:
                tgt_img_pil = Image.open(tgt_img_path).convert('RGB')
            except Exception as e:
                print(f"Failed to load the image: {tgt_img_path}")

            try:
                tgt_lip_mask = Image.open(lip_mask_path)
                # tgt_lip_mask = Image.open(lip_mask_path).convert('RGB')
                tgt_lip_mask = tgt_lip_mask.resize((self.width, self.height))
                tgt_lip_mask = torch.from_numpy(np.array(tgt_lip_mask)).float()
                # tgt_lip_mask = tgt_lip_mask / 127.5 - 1
                tgt_lip_mask = tgt_lip_mask / 255
            except Exception as e:
                print(f"Failed to load the lip masks: {lip_mask_path}")
                tgt_lip_mask = torch.ones(self.height, self.width)
                # tgt_lip_mask = torch.ones(self.height, self.width, 3)
            tgt_lip_masks_list.append(tgt_lip_mask)

            try:
                tgt_face_mask = Image.open(face_mask_path)
                # tgt_face_mask = Image.open(face_mask_path).convert('RGB')
                tgt_face_mask = tgt_face_mask.resize((self.width, self.height))
                tgt_face_mask = torch.from_numpy(np.array(tgt_face_mask)).float()
                tgt_face_mask = tgt_face_mask / 255
                # tgt_face_mask = tgt_face_mask / 127.5 - 1
            except Exception as e:
                print(f"Failed to load the face masks: {face_mask_path}")
                tgt_face_mask = torch.ones(self.height, self.width)
                # tgt_face_mask = torch.ones(self.height, self.width, 3)
            tgt_face_masks_list.append(tgt_face_mask)

            tgt_img_pil = tgt_img_pil.resize((self.width, self.height))
            tgt_img_tensor = torch.from_numpy(np.array(tgt_img_pil)).float()
            tgt_img_normalized = tgt_img_tensor / 127.5 - 1
            tgt_pil_image_list.append(tgt_img_normalized)

        sr = 16000
        vocal_input, sample_rate = librosa.load(vocal_audio_path, sr=sr)
        vocal_duration = librosa.get_duration(filename=vocal_audio_path)
        start_time = batch_index[0] / fps
        end_time = (clip_length / fps) + start_time
        start_sample = int(start_time * sr)
        end_sample = int(end_time * sr)
        try:
            vocal_segment = vocal_input[start_sample:end_sample]
        except:
            print(f"The current vocal segment is too short: {vocal_audio_path}, [{batch_index[0]}, {batch_index[-1]}], fps={fps}, clip_length={clip_length}, vocal_duration={vocal_duration}], [{start_time}, {end_time}]")
            vocal_segment = vocal_input[start_sample:]
        vocal_input_values = self.wav2vec_processor(
            vocal_segment, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values

        tgt_pil_image_list = torch.stack(tgt_pil_image_list, dim=0)
        tgt_pil_image_list = rearrange(tgt_pil_image_list, "f h w c -> f c h w")
        reference_pil_image = rearrange(reference_pil_image, "h w c -> c h w")

        tgt_face_masks_list = torch.stack(tgt_face_masks_list, dim=0)
        tgt_face_masks_list = torch.unsqueeze(tgt_face_masks_list, dim=-1)
        tgt_face_masks_list = rearrange(tgt_face_masks_list, "f h w c -> c f h w")
        tgt_lip_masks_list = torch.stack(tgt_lip_masks_list, dim=0)
        tgt_lip_masks_list = torch.unsqueeze(tgt_lip_masks_list, dim=-1)
        tgt_lip_masks_list = rearrange(tgt_lip_masks_list, "f h w c -> c f h w")

        clip_pixel_values = reference_pil_image.permute(1, 2, 0).contiguous()
        clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255

        cos_similarities = []
        stride = 8
        for i in range(0, tgt_pil_image_list.size()[0] - stride, stride):
            frame1 = tgt_pil_image_list[i]
            frame2 = tgt_pil_image_list[i + stride]
            frame1_flat = frame1.contiguous().view(-1)
            frame2_flat = frame2.contiguous().view(-1)
            cos_sim = F.cosine_similarity(frame1_flat, frame2_flat, dim=0)
            cos_sim = (cos_sim + 1) / 2
            cos_similarities.append(cos_sim.item())
        overall_cos_sim = F.cosine_similarity(tgt_pil_image_list[0].contiguous().view(-1), tgt_pil_image_list[-1].contiguous().view(-1), dim=0)
        overall_cos_sim = (overall_cos_sim + 1) / 2
        cos_similarities.append(overall_cos_sim.item())
        motion_id = (1.0 - sum(cos_similarities) / len(cos_similarities)) * 100

        if "singing" in self.video_files[idx]:
            text_prompt = "The protagonist is singing"
        elif "speech" in self.video_files[idx]:
            text_prompt = "The protagonist is talking"
        elif "dancing" in self.video_files[idx]:
            text_prompt = "The protagonist is simultaneously dancing and singing"
        else:
            text_prompt = ""
            print(1 / 0)  # deliberately raises on clips that match no known category

        sample = dict(
            pixel_values=tgt_pil_image_list,
            reference_image=reference_pil_image,
            clip_pixel_values=clip_pixel_values,
            tgt_face_masks=tgt_face_masks_list,
            vocal_input_values=vocal_input_values,
            text_prompt=text_prompt,
            motion_id=motion_id,
            tgt_lip_masks=tgt_lip_masks_list,
            audio_path=raw_audio_path,
        )

        if self.enable_inpaint:
            pixel_value_masks = get_random_mask(tgt_pil_image_list.size(), image_start_only=True)
            masked_pixel_values = tgt_pil_image_list * (1 - pixel_value_masks)
            sample["masked_pixel_values"] = masked_pixel_values
            sample["pixel_value_masks"] = pixel_value_masks

        return sample
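Note: a hedged sketch of how this dataset might be instantiated; the clip-list file, the resolution, the frame count, and the Wav2Vec2 checkpoint are placeholders, and wiring up the processor is an assumption of the training script, not something defined in this file.

# Hypothetical setup (placeholders throughout): the txt file lists one clip folder per line.
from torch.utils.data import DataLoader
from transformers import Wav2Vec2Processor
from wan.dataset.talking_video_dataset_fantasy import LargeScaleTalkingFantasyVideos

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
dataset = LargeScaleTalkingFantasyVideos(
    txt_path="train_clips.txt", width=480, height=832,
    n_sample_frames=81, sample_frame_rate=1,
    vae_stride=(4, 8, 8), patch_size=(1, 2, 2),
    wav2vec_processor=processor,
)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
sample = next(iter(loader))
print(sample["pixel_values"].shape, sample["text_prompt"])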
wan/dist/__init__.py
ADDED
@@ -0,0 +1,40 @@
import torch
import torch.distributed as dist

try:
    import xfuser
    from xfuser.core.distributed import (get_sequence_parallel_rank,
                                         get_sequence_parallel_world_size,
                                         get_sp_group, get_world_group,
                                         init_distributed_environment,
                                         initialize_model_parallel)
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention
except Exception as ex:
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    xFuserLongContextAttention = None
    get_sp_group = None
    get_world_group = None
    init_distributed_environment = None
    initialize_model_parallel = None

def set_multi_gpus_devices(ulysses_degree, ring_degree):
    if ulysses_degree > 1 or ring_degree > 1:
        if get_sp_group is None:
            raise RuntimeError("xfuser is not installed.")
        dist.init_process_group("nccl")
        print('parallel inference enabled: ulysses_degree=%d ring_degree=%d rank=%d world_size=%d' % (
            ulysses_degree, ring_degree, dist.get_rank(),
            dist.get_world_size()))
        assert dist.get_world_size() == ring_degree * ulysses_degree, \
            "number of GPUs(%d) should be equal to ring_degree * ulysses_degree." % dist.get_world_size()
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),
                                  ring_degree=ring_degree,
                                  ulysses_degree=ulysses_degree)
        # device = torch.device("cuda:%d" % dist.get_rank())
        device = torch.device(f"cuda:{get_world_group().local_rank}")
        print('rank=%d device=%s' % (get_world_group().rank, str(device)))
    else:
        device = "cuda"
    return device
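Note: with both degrees left at 1 the helper above simply returns "cuda"; larger degrees assume an xfuser install and a torchrun-style multi-process launch.

# Sketch: pick the device for (optionally sequence-parallel) inference.
from wan.dist import set_multi_gpus_devices

device = set_multi_gpus_devices(ulysses_degree=1, ring_degree=1)
print(device)  # "cuda" in the single-GPU case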
wan/dist/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (2.52 kB).
wan/dist/__pycache__/wan_xfuser.cpython-311.pyc
ADDED
Binary file (6 kB).
wan/dist/wan_xfuser.py
ADDED
@@ -0,0 +1,115 @@
import torch
import torch.amp as amp

try:
    import xfuser
    from xfuser.core.distributed import (get_sequence_parallel_rank,
                                         get_sequence_parallel_world_size,
                                         get_sp_group,
                                         init_distributed_environment,
                                         initialize_model_parallel)
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention
except Exception as ex:
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    xFuserLongContextAttention = None
    get_sp_group = None
    init_distributed_environment = None
    initialize_model_parallel = None

def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor

@amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """
    x: [B, L, N, C].
    grid_sizes: [B, 3].
    freqs: [M, C // 2].
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
            s, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        sp_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
                                                       s_per_rank), :, :]
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output)

def usp_attn_forward(self,
                     x,
                     seq_lens,
                     grid_sizes,
                     freqs,
                     dtype=torch.bfloat16):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v

    q, k, v = qkv_fn(x)
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    # TODO: We should use unpaded q,k,v for attention.
    # k_lens = seq_lens // get_sequence_parallel_world_size()
    # if k_lens is not None:
    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)

    x = xFuserLongContextAttention()(
        None,
        query=half(q),
        key=half(k),
        value=half(v),
        window_size=self.window_size)

    # TODO: padding after attention.
    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)

    # output
    x = x.flatten(2)
    x = self.o(x)
    return x
wan/distributed/__init__.py
ADDED
File without changes
wan/distributed/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (183 Bytes).
wan/distributed/__pycache__/fsdp.cpython-311.pyc
ADDED
Binary file (2.05 kB).
wan/distributed/fsdp.py
ADDED
@@ -0,0 +1,41 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage

def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model

def free_model(model):
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
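Note: shard_model keys its auto-wrap policy on membership in model.blocks, so it expects a module that exposes a blocks list. The toy module below is only an illustration, and the actual call is left commented because it needs an initialized NCCL process group.

# Hedged sketch: a stand-in module with a `.blocks` ModuleList, which is what
# the lambda auto-wrap policy above keys on.
import torch.nn as nn

class TinyBlocksModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

# Under an initialized process group (assumed), each block gets wrapped by FSDP:
# from wan.distributed.fsdp import shard_model
# model = shard_model(TinyBlocksModel().cuda(), device_id=torch.cuda.current_device())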
wan/distributed/xdit_context_parallel.py
ADDED
|
@@ -0,0 +1,192 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import torch.cuda.amp as amp
+from xfuser.core.distributed import (get_sequence_parallel_rank,
+                                     get_sequence_parallel_world_size,
+                                     get_sp_group)
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+
+from ..modules.model import sinusoidal_embedding_1d
+
+
+def pad_freqs(original_tensor, target_len):
+    seq_len, s1, s2 = original_tensor.shape
+    pad_size = target_len - seq_len
+    padding_tensor = torch.ones(
+        pad_size,
+        s1,
+        s2,
+        dtype=original_tensor.dtype,
+        device=original_tensor.device)
+    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+    return padded_tensor
+
+
+@amp.autocast(enabled=False)
+def rope_apply(x, grid_sizes, freqs):
+    """
+    x:          [B, L, N, C].
+    grid_sizes: [B, 3].
+    freqs:      [M, C // 2].
+    """
+    s, n, c = x.size(1), x.size(2), x.size(3) // 2
+    # split freqs
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+    # loop over samples
+    output = []
+    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+        seq_len = f * h * w
+
+        # precompute multipliers
+        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
+            s, n, -1, 2))
+        freqs_i = torch.cat([
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+        ],
+                            dim=-1).reshape(seq_len, 1, -1)
+
+        # apply rotary embedding
+        sp_size = get_sequence_parallel_world_size()
+        sp_rank = get_sequence_parallel_rank()
+        freqs_i = pad_freqs(freqs_i, s * sp_size)
+        s_per_rank = s
+        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
+                                                       s_per_rank), :, :]
+        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
+        x_i = torch.cat([x_i, x[i, s:]])
+
+        # append to collection
+        output.append(x_i)
+    return torch.stack(output).float()
+
+
+def usp_dit_forward(
+    self,
+    x,
+    t,
+    context,
+    seq_len,
+    clip_fea=None,
+    y=None,
+):
+    """
+    x:              A list of videos each with shape [C, T, H, W].
+    t:              [B].
+    context:        A list of text embeddings each with shape [L, C].
+    """
+    if self.model_type == 'i2v':
+        assert clip_fea is not None and y is not None
+    # params
+    device = self.patch_embedding.weight.device
+    if self.freqs.device != device:
+        self.freqs = self.freqs.to(device)
+
+    if y is not None:
+        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+    # embeddings
+    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+    grid_sizes = torch.stack(
+        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+    x = [u.flatten(2).transpose(1, 2) for u in x]
+    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+    assert seq_lens.max() <= seq_len
+    x = torch.cat([
+        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
+        for u in x
+    ])
+
+    # time embeddings
+    with amp.autocast(dtype=torch.float32):
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t).float())
+        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+        assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+    # context
+    context_lens = None
+    context = self.text_embedding(
+        torch.stack([
+            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+            for u in context
+        ]))
+
+    if clip_fea is not None:
+        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+        context = torch.concat([context_clip, context], dim=1)
+
+    # arguments
+    kwargs = dict(
+        e=e0,
+        seq_lens=seq_lens,
+        grid_sizes=grid_sizes,
+        freqs=self.freqs,
+        context=context,
+        context_lens=context_lens)
+
+    # Context Parallel
+    x = torch.chunk(
+        x, get_sequence_parallel_world_size(),
+        dim=1)[get_sequence_parallel_rank()]
+
+    for block in self.blocks:
+        x = block(x, **kwargs)
+
+    # head
+    x = self.head(x, e)
+
+    # Context Parallel
+    x = get_sp_group().all_gather(x, dim=1)
+
+    # unpatchify
+    x = self.unpatchify(x, grid_sizes)
+    return [u.float() for u in x]
+
+
+def usp_attn_forward(self,
+                     x,
+                     seq_lens,
+                     grid_sizes,
+                     freqs,
+                     dtype=torch.bfloat16):
+    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+    half_dtypes = (torch.float16, torch.bfloat16)
+
+    def half(x):
+        return x if x.dtype in half_dtypes else x.to(dtype)
+
+    # query, key, value function
+    def qkv_fn(x):
+        q = self.norm_q(self.q(x)).view(b, s, n, d)
+        k = self.norm_k(self.k(x)).view(b, s, n, d)
+        v = self.v(x).view(b, s, n, d)
+        return q, k, v
+
+    q, k, v = qkv_fn(x)
+    q = rope_apply(q, grid_sizes, freqs)
+    k = rope_apply(k, grid_sizes, freqs)
+
+    # TODO: We should use unpadded q, k, v for attention.
+    # k_lens = seq_lens // get_sequence_parallel_world_size()
+    # if k_lens is not None:
+    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
+    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
+    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
+
+    x = xFuserLongContextAttention()(
+        None,
+        query=half(q),
+        key=half(k),
+        value=half(v),
+        window_size=self.window_size)
+
+    # TODO: padding after attention.
+    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
+
+    # output
+    x = x.flatten(2)
+    x = self.o(x)
+    return x
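A note on the forward above: the context-parallel split is a plain, even `torch.chunk` along the token dimension, which is why callers round the sequence length up to a multiple of the sequence-parallel world size before dispatch. A toy, single-process sketch of that arithmetic (all numbers below are made up for illustration, not taken from the commit):

import math
import torch

world_size = 4          # hypothetical sequence-parallel group size
seq_len = 10            # toy number of patch tokens before padding
dim = 8                 # toy channel dimension

padded_len = math.ceil(seq_len / world_size) * world_size   # 12
x = torch.zeros(1, padded_len, dim)                         # [B, L, C] after padding

shards = torch.chunk(x, world_size, dim=1)
assert all(s.shape[1] == padded_len // world_size for s in shards)
# Each rank runs the transformer blocks on one such shard;
# get_sp_group().all_gather(..., dim=1) then reassembles the full sequence
# before unpatchify.
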
wan/image2video.py
ADDED
@@ -0,0 +1,334 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import gc
+import logging
+import math
+import os
+import random
+import sys
+import types
+from contextlib import contextmanager
+from functools import partial
+
+import numpy as np
+import torch
+import torch.cuda.amp as amp
+import torch.distributed as dist
+import torchvision.transforms.functional as TF
+from tqdm import tqdm
+
+from .distributed.fsdp import shard_model
+from .modules.clip import CLIPModel
+from .modules.model import WanModel
+from .modules.t5 import T5EncoderModel
+from .modules.vae import WanVAE
+from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
+                               get_sampling_sigmas, retrieve_timesteps)
+from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+
+
+class WanI2V:
+
+    def __init__(
+        self,
+        config,
+        checkpoint_dir,
+        device_id=0,
+        rank=0,
+        t5_fsdp=False,
+        dit_fsdp=False,
+        use_usp=False,
+        t5_cpu=False,
+        init_on_cpu=True,
+    ):
+        r"""
+        Initializes the image-to-video generation model components.
+
+        Args:
+            config (EasyDict):
+                Object containing model parameters initialized from config.py
+            checkpoint_dir (`str`):
+                Path to directory containing model checkpoints
+            device_id (`int`, *optional*, defaults to 0):
+                Id of target GPU device
+            rank (`int`, *optional*, defaults to 0):
+                Process rank for distributed training
+            t5_fsdp (`bool`, *optional*, defaults to False):
+                Enable FSDP sharding for T5 model
+            dit_fsdp (`bool`, *optional*, defaults to False):
+                Enable FSDP sharding for DiT model
+            use_usp (`bool`, *optional*, defaults to False):
+                Enable distribution strategy of USP.
+            t5_cpu (`bool`, *optional*, defaults to False):
+                Whether to place T5 model on CPU. Only works without t5_fsdp.
+            init_on_cpu (`bool`, *optional*, defaults to True):
+                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
+        """
+        self.device = torch.device(f"cuda:{device_id}")
+        self.config = config
+        self.rank = rank
+        self.use_usp = use_usp
+        self.t5_cpu = t5_cpu
+
+        self.num_train_timesteps = config.num_train_timesteps
+        self.param_dtype = config.param_dtype
+
+        shard_fn = partial(shard_model, device_id=device_id)
+        self.text_encoder = T5EncoderModel(
+            text_len=config.text_len,
+            dtype=config.t5_dtype,
+            device=torch.device('cpu'),
+            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
+            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
+            shard_fn=shard_fn if t5_fsdp else None,
+        )
+
+        self.vae_stride = config.vae_stride
+        self.patch_size = config.patch_size
+        self.vae = WanVAE(
+            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
+            device=self.device)
+
+        self.clip = CLIPModel(
+            dtype=config.clip_dtype,
+            device=self.device,
+            checkpoint_path=os.path.join(checkpoint_dir,
+                                         config.clip_checkpoint),
+            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
+
+        logging.info(f"Creating WanModel from {checkpoint_dir}")
+        self.model = WanModel.from_pretrained(checkpoint_dir)
+        self.model.eval().requires_grad_(False)
+
+        if t5_fsdp or dit_fsdp or use_usp:
+            init_on_cpu = False
+
+        if use_usp:
+            from xfuser.core.distributed import \
+                get_sequence_parallel_world_size
+
+            from .distributed.xdit_context_parallel import (usp_attn_forward,
+                                                            usp_dit_forward)
+            for block in self.model.blocks:
+                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+            self.model.forward = types.MethodType(usp_dit_forward, self.model)
+            self.sp_size = get_sequence_parallel_world_size()
+        else:
+            self.sp_size = 1
+
+        if dist.is_initialized():
+            dist.barrier()
+        if dit_fsdp:
+            self.model = shard_fn(self.model)
+        else:
+            if not init_on_cpu:
+                self.model.to(self.device)
+
+        self.sample_neg_prompt = config.sample_neg_prompt
+
+    def generate(self,
+                 input_prompt,
+                 img,
+                 max_area=720 * 1280,
+                 frame_num=81,
+                 shift=5.0,
+                 sample_solver='unipc',
+                 sampling_steps=40,
+                 guide_scale=5.0,
+                 n_prompt="",
+                 seed=-1,
+                 offload_model=True):
+        r"""
+        Generates video frames from input image and text prompt using diffusion process.
+
+        Args:
+            input_prompt (`str`):
+                Text prompt for content generation.
+            img (PIL.Image.Image):
+                Input reference image. Shape after conversion: [3, H, W]
+            max_area (`int`, *optional*, defaults to 720*1280):
+                Maximum pixel area for latent space calculation. Controls video resolution scaling
+            frame_num (`int`, *optional*, defaults to 81):
+                How many frames to sample from a video. The number should be 4n+1
+            shift (`float`, *optional*, defaults to 5.0):
+                Noise schedule shift parameter. Affects temporal dynamics
+                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
+            sample_solver (`str`, *optional*, defaults to 'unipc'):
+                Solver used to sample the video.
+            sampling_steps (`int`, *optional*, defaults to 40):
+                Number of diffusion sampling steps. Higher values improve quality but slow generation
+            guide_scale (`float`, *optional*, defaults to 5.0):
+                Classifier-free guidance scale. Controls prompt adherence vs. creativity
+            n_prompt (`str`, *optional*, defaults to ""):
+                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+            seed (`int`, *optional*, defaults to -1):
+                Random seed for noise generation. If -1, use random seed
+            offload_model (`bool`, *optional*, defaults to True):
+                If True, offloads models to CPU during generation to save VRAM
+
+        Returns:
+            torch.Tensor:
+                Generated video frames tensor. Dimensions: (C, N, H, W) where:
+                - C: Color channels (3 for RGB)
+                - N: Number of frames (81)
+                - H: Frame height (from max_area)
+                - W: Frame width (from max_area)
+        """
+        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
+
+        F = frame_num
+        h, w = img.shape[1:]
+        aspect_ratio = h / w
+        lat_h = round(
+            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
+            self.patch_size[1] * self.patch_size[1])
+        lat_w = round(
+            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
+            self.patch_size[2] * self.patch_size[2])
+        h = lat_h * self.vae_stride[1]
+        w = lat_w * self.vae_stride[2]
+
+        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2])
+        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
+
+        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+        seed_g = torch.Generator(device=self.device)
+        seed_g.manual_seed(seed)
+        noise = torch.randn(
+            16,
+            21,
+            lat_h,
+            lat_w,
+            dtype=torch.float32,
+            generator=seed_g,
+            device=self.device)
+
+        msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
+        msk[:, 1:] = 0
+        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+        msk = msk.transpose(1, 2)[0]
+
+        if n_prompt == "":
+            n_prompt = self.sample_neg_prompt
+
+        # preprocess
+        if not self.t5_cpu:
+            self.text_encoder.model.to(self.device)
+            context = self.text_encoder([input_prompt], self.device)
+            context_null = self.text_encoder([n_prompt], self.device)
+            if offload_model:
+                self.text_encoder.model.cpu()
+        else:
+            context = self.text_encoder([input_prompt], torch.device('cpu'))
+            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            context = [t.to(self.device) for t in context]
+            context_null = [t.to(self.device) for t in context_null]
+
+        self.clip.model.to(self.device)
+        clip_context = self.clip.visual([img[:, None, :, :]])
+        if offload_model:
+            self.clip.model.cpu()
+
+        y = self.vae.encode([torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).to(self.device)])[0]
+        y = torch.concat([msk, y])
+
+        @contextmanager
+        def noop_no_sync():
+            yield
+
+        no_sync = getattr(self.model, 'no_sync', noop_no_sync)
+
+        # evaluation mode
+        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
+
+            if sample_solver == 'unipc':
+                sample_scheduler = FlowUniPCMultistepScheduler(
+                    num_train_timesteps=self.num_train_timesteps,
+                    shift=1,
+                    use_dynamic_shifting=False)
+                sample_scheduler.set_timesteps(
+                    sampling_steps, device=self.device, shift=shift)
+                timesteps = sample_scheduler.timesteps
+            elif sample_solver == 'dpm++':
+                sample_scheduler = FlowDPMSolverMultistepScheduler(
+                    num_train_timesteps=self.num_train_timesteps,
+                    shift=1,
+                    use_dynamic_shifting=False)
+                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+                timesteps, _ = retrieve_timesteps(
+                    sample_scheduler,
+                    device=self.device,
+                    sigmas=sampling_sigmas)
+            else:
+                raise NotImplementedError("Unsupported solver.")
+
+            # sample videos
+            latent = noise
+
+            arg_c = {
+                'context': [context[0]],
+                'clip_fea': clip_context,
+                'seq_len': max_seq_len,
+                'y': [y],
+            }
+
+            arg_null = {
+                'context': context_null,
+                'clip_fea': clip_context,
+                'seq_len': max_seq_len,
+                'y': [y],
+            }
+
+            if offload_model:
+                torch.cuda.empty_cache()
+
+            self.model.to(self.device)
+            for _, t in enumerate(tqdm(timesteps)):
+                latent_model_input = [latent.to(self.device)]
+                timestep = [t]
+
+                timestep = torch.stack(timestep).to(self.device)
+
+                noise_pred_cond = self.model(
+                    latent_model_input, t=timestep, **arg_c)[0].to(
+                        torch.device('cpu') if offload_model else self.device)
+                if offload_model:
+                    torch.cuda.empty_cache()
+                noise_pred_uncond = self.model(
+                    latent_model_input, t=timestep, **arg_null)[0].to(
+                        torch.device('cpu') if offload_model else self.device)
+                if offload_model:
+                    torch.cuda.empty_cache()
+                noise_pred = noise_pred_uncond + guide_scale * (
+                    noise_pred_cond - noise_pred_uncond)
+
+                latent = latent.to(
+                    torch.device('cpu') if offload_model else self.device)
+
+                temp_x0 = sample_scheduler.step(
+                    noise_pred.unsqueeze(0),
+                    t,
+                    latent.unsqueeze(0),
+                    return_dict=False,
+                    generator=seed_g)[0]
+                latent = temp_x0.squeeze(0)
+
+                x0 = [latent.to(self.device)]
+                del latent_model_input, timestep
+
+            if offload_model:
+                self.model.cpu()
+                torch.cuda.empty_cache()
+
+            if self.rank == 0:
+                videos = self.vae.decode(x0)
+
+        del noise, latent
+        del sample_scheduler
+        if offload_model:
+            gc.collect()
+            torch.cuda.synchronize()
+        if dist.is_initialized():
+            dist.barrier()
+
+        return videos[0] if self.rank == 0 else None
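Putting the docstrings above together, a single-GPU call might look roughly like the sketch below. The `WAN_CONFIGS` registry key and the checkpoint directory are assumptions (they are not shipped with this commit); only the example reference image comes from the repository's example_case folder.

from PIL import Image

from wan.configs import WAN_CONFIGS          # assumed config registry
from wan.image2video import WanI2V

pipeline = WanI2V(
    config=WAN_CONFIGS['i2v-14B'],           # assumed config key
    checkpoint_dir='./Wan2.1-I2V-14B-720P',  # placeholder checkpoint path
    device_id=0)

video = pipeline.generate(
    input_prompt='A woman speaks to the camera in a sunlit room.',
    img=Image.open('example_case/case-1/reference.png'),
    max_area=720 * 1280,
    frame_num=81,        # must be 4n + 1
    sampling_steps=40,
    guide_scale=5.0,
    seed=42)
# On rank 0, `video` is a (C, N, H, W) float tensor; other ranks receive None.
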
wan/models/__init__.py
ADDED
File without changes

wan/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (182 Bytes)