Spaces:
Running
Running
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +9 -0
- .gitignore +54 -0
- DEPLOYMENT.md +241 -0
- LICENSE.txt +201 -0
- PROJECT_SUMMARY.md +260 -0
- QUICK_START.md +186 -0
- README.md +55 -6
- TODO.md +93 -0
- app.py +437 -0
- assets/InfiniteTalk_paper.pdf +3 -0
- assets/logo.jpg +0 -0
- assets/logo2.jpg +3 -0
- assets/pipeline.png +3 -0
- examples/multi/1-man.WAV +3 -0
- examples/multi/1-woman.WAV +3 -0
- examples/multi/ref_img.png +3 -0
- examples/multi_example_image.json +9 -0
- examples/single/1.wav +3 -0
- examples/single/ref_image.png +3 -0
- examples/single/ref_video.mp4 +3 -0
- examples/single_example_image.json +7 -0
- examples/single_example_video.json +7 -0
- packages.txt +4 -0
- requirements.txt +42 -0
- src/audio_analysis/torch_utils.py +20 -0
- src/audio_analysis/wav2vec2.py +125 -0
- src/utils.py +60 -0
- src/vram_management/__init__.py +1 -0
- src/vram_management/layers.py +243 -0
- utils/__init__.py +6 -0
- utils/gpu_manager.py +221 -0
- utils/model_loader.py +195 -0
- wan/__init__.py +6 -0
- wan/configs/__init__.py +58 -0
- wan/configs/shared_config.py +19 -0
- wan/configs/wan_i2v_14B.py +24 -0
- wan/configs/wan_multitalk_14B.py +36 -0
- wan/configs/wan_t2v_14B.py +29 -0
- wan/configs/wan_t2v_1_3B.py +29 -0
- wan/distributed/__init__.py +0 -0
- wan/distributed/fsdp.py +43 -0
- wan/distributed/xdit_context_parallel.py +550 -0
- wan/first_last_frame2video.py +377 -0
- wan/image2video.py +350 -0
- wan/modules/__init__.py +18 -0
- wan/modules/attention.py +393 -0
- wan/modules/clip.py +542 -0
- wan/modules/model.py +631 -0
- wan/modules/multitalk_model.py +824 -0
- wan/modules/t5.py +535 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/InfiniteTalk_paper.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/logo2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/pipeline.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
examples/multi/1-man.WAV filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
examples/multi/1-woman.WAV filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
examples/multi/ref_img.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
examples/single/1.wav filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
examples/single/ref_image.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
examples/single/ref_video.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Models and weights
|
| 24 |
+
weights/
|
| 25 |
+
*.safetensors
|
| 26 |
+
*.bin
|
| 27 |
+
*.ckpt
|
| 28 |
+
*.pth
|
| 29 |
+
|
| 30 |
+
# Temporary files
|
| 31 |
+
temp-*/
|
| 32 |
+
/tmp/
|
| 33 |
+
*.tmp
|
| 34 |
+
*.log
|
| 35 |
+
|
| 36 |
+
# IDE
|
| 37 |
+
.vscode/
|
| 38 |
+
.idea/
|
| 39 |
+
*.swp
|
| 40 |
+
*.swo
|
| 41 |
+
*~
|
| 42 |
+
|
| 43 |
+
# OS
|
| 44 |
+
.DS_Store
|
| 45 |
+
Thumbs.db
|
| 46 |
+
|
| 47 |
+
# Gradio
|
| 48 |
+
flagged/
|
| 49 |
+
|
| 50 |
+
# Environment
|
| 51 |
+
.env
|
| 52 |
+
venv/
|
| 53 |
+
ENV/
|
| 54 |
+
env/
|
DEPLOYMENT.md
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# InfiniteTalk - Deployment Guide
|
| 2 |
+
|
| 3 |
+
## Prerequisites
|
| 4 |
+
|
| 5 |
+
1. **HuggingFace Account**: Sign up at https://huggingface.co
|
| 6 |
+
2. **Git & Git LFS**: Install from https://git-scm.com
|
| 7 |
+
3. **HuggingFace CLI** (optional but recommended):
|
| 8 |
+
```bash
|
| 9 |
+
pip install huggingface_hub
|
| 10 |
+
huggingface-cli login
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
## Deployment Steps
|
| 14 |
+
|
| 15 |
+
### Option 1: Web UI (Easiest)
|
| 16 |
+
|
| 17 |
+
1. **Create New Space**
|
| 18 |
+
- Go to https://huggingface.co/new-space
|
| 19 |
+
- Space name: `infinitetalk` (or your choice)
|
| 20 |
+
- License: `apache-2.0`
|
| 21 |
+
- SDK: `Gradio`
|
| 22 |
+
- Hardware: `ZeroGPU` (free tier available!)
|
| 23 |
+
- Click "Create Space"
|
| 24 |
+
|
| 25 |
+
2. **Upload Files**
|
| 26 |
+
- Click "Files" tab in your new Space
|
| 27 |
+
- Upload all files from this directory:
|
| 28 |
+
- `README.md` (with YAML metadata)
|
| 29 |
+
- `app.py`
|
| 30 |
+
- `requirements.txt`
|
| 31 |
+
- `packages.txt`
|
| 32 |
+
- `.gitignore`
|
| 33 |
+
- `src/` folder
|
| 34 |
+
- `wan/` folder
|
| 35 |
+
- `utils/` folder
|
| 36 |
+
- `assets/` folder (optional)
|
| 37 |
+
- `examples/` folder (optional)
|
| 38 |
+
- `LICENSE.txt`
|
| 39 |
+
|
| 40 |
+
3. **Wait for Build**
|
| 41 |
+
- Space will automatically build
|
| 42 |
+
- First build takes 5-10 minutes (installing dependencies)
|
| 43 |
+
- Check "Logs" tab for build progress
|
| 44 |
+
- Watch for any error messages
|
| 45 |
+
|
| 46 |
+
4. **Test Your Space**
|
| 47 |
+
- Once built, the Space will show "Running"
|
| 48 |
+
- First generation will download models (~2-3 minutes)
|
| 49 |
+
- Try with example images/audio
|
| 50 |
+
|
| 51 |
+
### Option 2: Git (Advanced)
|
| 52 |
+
|
| 53 |
+
1. **Clone Your Space**
|
| 54 |
+
```bash
|
| 55 |
+
git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
|
| 56 |
+
cd YOUR_SPACE_NAME
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
2. **Copy Files**
|
| 60 |
+
```bash
|
| 61 |
+
# From your local infinitetalk-hf-space directory
|
| 62 |
+
cp -r /path/to/infinitetalk-hf-space/* .
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
3. **Commit and Push**
|
| 66 |
+
```bash
|
| 67 |
+
git add .
|
| 68 |
+
git commit -m "Initial InfiniteTalk Space deployment"
|
| 69 |
+
git push
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
4. **Monitor Build**
|
| 73 |
+
- Go to your Space URL
|
| 74 |
+
- Check "Logs" for build progress
|
| 75 |
+
|
| 76 |
+
### Option 3: CLI Upload
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
# From this directory
|
| 80 |
+
huggingface-cli upload YOUR_USERNAME/YOUR_SPACE_NAME . --repo-type=space
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
## Troubleshooting
|
| 84 |
+
|
| 85 |
+
### Build Fails with Flash-Attn Error
|
| 86 |
+
|
| 87 |
+
**Symptom**: `flash-attn` compilation fails
|
| 88 |
+
|
| 89 |
+
**Solutions**:
|
| 90 |
+
1. Try adding to `requirements.txt`:
|
| 91 |
+
```
|
| 92 |
+
flash-attn==2.7.4.post1 --no-build-isolation
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
2. Or use Dockerfile approach (create `Dockerfile`):
|
| 96 |
+
```dockerfile
|
| 97 |
+
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04
|
| 98 |
+
|
| 99 |
+
RUN apt-get update && apt-get install -y \
|
| 100 |
+
python3.10 python3-pip git ffmpeg build-essential libsndfile1
|
| 101 |
+
|
| 102 |
+
WORKDIR /app
|
| 103 |
+
|
| 104 |
+
# Install PyTorch first
|
| 105 |
+
RUN pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
|
| 106 |
+
|
| 107 |
+
# Install flash-attn with pre-built wheels
|
| 108 |
+
RUN pip install flash-attn==2.7.4.post1 --no-build-isolation
|
| 109 |
+
|
| 110 |
+
# Copy and install requirements
|
| 111 |
+
COPY requirements.txt .
|
| 112 |
+
RUN pip install -r requirements.txt
|
| 113 |
+
|
| 114 |
+
# Copy application
|
| 115 |
+
COPY . .
|
| 116 |
+
|
| 117 |
+
CMD ["python3", "app.py"]
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
### Models Not Downloading
|
| 121 |
+
|
| 122 |
+
**Symptom**: "Model download failed" error
|
| 123 |
+
|
| 124 |
+
**Solutions**:
|
| 125 |
+
1. Check HuggingFace is not down: https://status.huggingface.co
|
| 126 |
+
2. Add HF_TOKEN secret in Space settings (for private models)
|
| 127 |
+
3. Check model repository IDs in `utils/model_loader.py`
|
| 128 |
+
|
| 129 |
+
### Out of Memory (OOM) Errors
|
| 130 |
+
|
| 131 |
+
**Symptom**: "CUDA out of memory"
|
| 132 |
+
|
| 133 |
+
**Solutions**:
|
| 134 |
+
1. Reduce resolution (use 480p instead of 720p)
|
| 135 |
+
2. Reduce diffusion steps (try 30 instead of 40)
|
| 136 |
+
3. Process shorter videos
|
| 137 |
+
4. Check `utils/gpu_manager.py` settings
|
| 138 |
+
|
| 139 |
+
### Space Stuck in "Building"
|
| 140 |
+
|
| 141 |
+
**Symptom**: Build takes >15 minutes
|
| 142 |
+
|
| 143 |
+
**Solutions**:
|
| 144 |
+
1. Check "Logs" tab for errors
|
| 145 |
+
2. Flash-attn compilation can take 10+ minutes
|
| 146 |
+
3. If timeout, try Dockerfile approach
|
| 147 |
+
4. Consider pre-built flash-attn wheels
|
| 148 |
+
|
| 149 |
+
### ZeroGPU Quota Exceeded
|
| 150 |
+
|
| 151 |
+
**Symptom**: "GPU quota exceeded"
|
| 152 |
+
|
| 153 |
+
**Solutions**:
|
| 154 |
+
1. **Free Tier**: Wait for quota to refill (1 ZeroGPU second = 30 real seconds)
|
| 155 |
+
2. **Upgrade to PRO**: $9/month for 8× quota
|
| 156 |
+
3. **Apply for Grant**: Community GPU Grant for innovative projects
|
| 157 |
+
4. Optimize generation time (reduce steps, use 480p)
|
| 158 |
+
|
| 159 |
+
## Post-Deployment
|
| 160 |
+
|
| 161 |
+
### Monitor Usage
|
| 162 |
+
- Check "Logs" tab regularly
|
| 163 |
+
- Monitor GPU quota in Space settings
|
| 164 |
+
- Watch for user error reports in Community tab
|
| 165 |
+
|
| 166 |
+
### Update Space
|
| 167 |
+
```bash
|
| 168 |
+
# Make changes locally
|
| 169 |
+
git add .
|
| 170 |
+
git commit -m "Update: [description]"
|
| 171 |
+
git push
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
Space will automatically rebuild on push.
|
| 175 |
+
|
| 176 |
+
### Add Examples
|
| 177 |
+
Upload example images and audio to `examples/` folder to help users get started quickly.
|
| 178 |
+
|
| 179 |
+
### Enable Discussions
|
| 180 |
+
In Space settings, enable "Discussions" to get user feedback.
|
| 181 |
+
|
| 182 |
+
### Apply for Community GPU Grant
|
| 183 |
+
If your Space is popular and useful:
|
| 184 |
+
1. Go to Space Settings
|
| 185 |
+
2. Click "Apply for community GPU grant"
|
| 186 |
+
3. Explain your project's value to the community
|
| 187 |
+
|
| 188 |
+
## Hardware Options
|
| 189 |
+
|
| 190 |
+
### Free ZeroGPU
|
| 191 |
+
- **Cost**: FREE
|
| 192 |
+
- **Limits**: 300s per session, 600s max quota
|
| 193 |
+
- **Best for**: Testing, light usage, demos
|
| 194 |
+
- **GPU**: H200 with 70GB VRAM
|
| 195 |
+
|
| 196 |
+
### PRO ZeroGPU
|
| 197 |
+
- **Cost**: $9/month
|
| 198 |
+
- **Benefits**: 8× quota, priority queue, 10 Spaces
|
| 199 |
+
- **Best for**: Regular usage, public demos
|
| 200 |
+
|
| 201 |
+
### Dedicated GPU (Paid)
|
| 202 |
+
- **T4 (16GB)**: $0.60/hour - Too small for InfiniteTalk
|
| 203 |
+
- **A10G (24GB)**: $1.05/hour - Minimum viable
|
| 204 |
+
- **A100 (40GB)**: $3.00/hour - Overkill but works
|
| 205 |
+
- **Best for**: Private, dedicated instances
|
| 206 |
+
|
| 207 |
+
## Performance Expectations
|
| 208 |
+
|
| 209 |
+
### First Generation
|
| 210 |
+
- Model download: 2-3 minutes
|
| 211 |
+
- Generation (10s video, 480p): 40 seconds
|
| 212 |
+
- **Total**: ~3-4 minutes
|
| 213 |
+
|
| 214 |
+
### Subsequent Generations
|
| 215 |
+
- Generation (10s video, 480p): 35-40 seconds
|
| 216 |
+
- Generation (10s video, 720p): 60-70 seconds
|
| 217 |
+
|
| 218 |
+
### Free Tier Usage
|
| 219 |
+
- ~3-5 generations per quota period (600s ZeroGPU)
|
| 220 |
+
- Quota refills gradually (1 ZeroGPU second per 30 real seconds)
|
| 221 |
+
|
| 222 |
+
## Support
|
| 223 |
+
|
| 224 |
+
- **Issues**: File at https://github.com/MeiGen-AI/InfiniteTalk/issues
|
| 225 |
+
- **Discussions**: Use Space's Community tab
|
| 226 |
+
- **HF Forums**: https://discuss.huggingface.co
|
| 227 |
+
|
| 228 |
+
## Success Checklist
|
| 229 |
+
|
| 230 |
+
- [ ] Space builds without errors
|
| 231 |
+
- [ ] Models download successfully on first run
|
| 232 |
+
- [ ] Example image-to-video generation works
|
| 233 |
+
- [ ] Example video dubbing works
|
| 234 |
+
- [ ] No OOM errors with 480p
|
| 235 |
+
- [ ] GPU memory is cleaned up between runs
|
| 236 |
+
- [ ] Gradio UI is responsive
|
| 237 |
+
- [ ] Examples are loaded and working
|
| 238 |
+
- [ ] README displays correctly
|
| 239 |
+
- [ ] Space doesn't crash after multiple uses
|
| 240 |
+
|
| 241 |
+
Good luck with your deployment! 🚀
|
LICENSE.txt
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
PROJECT_SUMMARY.md
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# InfiniteTalk HuggingFace Space - Project Summary
|
| 2 |
+
|
| 3 |
+
## ✅ What Has Been Completed
|
| 4 |
+
|
| 5 |
+
### 1. Project Structure Setup
|
| 6 |
+
```
|
| 7 |
+
infinitetalk-hf-space/
|
| 8 |
+
├── README.md ✅ Space metadata with ZeroGPU config
|
| 9 |
+
├── app.py ✅ Gradio interface with dual tabs
|
| 10 |
+
├── requirements.txt ✅ Carefully ordered dependencies
|
| 11 |
+
├── packages.txt ✅ System dependencies (ffmpeg, etc.)
|
| 12 |
+
├── .gitignore ✅ Ignore patterns for weights/temp files
|
| 13 |
+
├── LICENSE.txt ✅ Apache 2.0 license
|
| 14 |
+
├── TODO.md ✅ Next steps for completion
|
| 15 |
+
├── DEPLOYMENT.md ✅ Deployment guide
|
| 16 |
+
├── src/ ✅ Audio analysis modules from repo
|
| 17 |
+
├── wan/ ✅ Wan model integration from repo
|
| 18 |
+
├── utils/
|
| 19 |
+
│ ├── __init__.py ✅ Module initialization
|
| 20 |
+
│ ├── model_loader.py ✅ HuggingFace Hub model manager
|
| 21 |
+
│ └── gpu_manager.py ✅ Memory monitoring & optimization
|
| 22 |
+
├── assets/ ✅ Assets from repo
|
| 23 |
+
└── examples/ ✅ Example images/videos/configs
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
### 2. Core Components Created
|
| 27 |
+
|
| 28 |
+
#### ✅ README.md
|
| 29 |
+
- Proper YAML frontmatter for HuggingFace Spaces
|
| 30 |
+
- `hardware: zero-gpu` configuration
|
| 31 |
+
- `sdk: gradio` specification
|
| 32 |
+
- User-facing documentation
|
| 33 |
+
- Feature descriptions and usage guide
|
| 34 |
+
|
| 35 |
+
#### ✅ app.py (Main Application)
|
| 36 |
+
- **Dual-mode Gradio interface**:
|
| 37 |
+
- Image-to-Video tab
|
| 38 |
+
- Video Dubbing tab
|
| 39 |
+
- **ZeroGPU integration**:
|
| 40 |
+
- `@spaces.GPU` decorator on generate function
|
| 41 |
+
- Dynamic duration calculation
|
| 42 |
+
- Memory optimization
|
| 43 |
+
- **User-friendly UI**:
|
| 44 |
+
- Advanced settings in collapsible accordions
|
| 45 |
+
- Progress indicators
|
| 46 |
+
- Example inputs
|
| 47 |
+
- Error handling
|
| 48 |
+
- **Input validation**:
|
| 49 |
+
- File type checking
|
| 50 |
+
- Parameter range validation
|
| 51 |
+
- Clear error messages
|
| 52 |
+
|
| 53 |
+
#### ✅ utils/model_loader.py (Model Management)
|
| 54 |
+
- **Lazy loading pattern** - models download on first use
|
| 55 |
+
- **HuggingFace Hub integration** - automatic downloads
|
| 56 |
+
- **Model caching** - uses `/data/.huggingface` for persistence
|
| 57 |
+
- **Multi-model support**:
|
| 58 |
+
- Wan2.1-I2V-14B model
|
| 59 |
+
- InfiniteTalk weights
|
| 60 |
+
- Wav2Vec2 audio encoder
|
| 61 |
+
- **Memory-mapped loading** for large models
|
| 62 |
+
- **Graceful error handling**
|
| 63 |
+
|
| 64 |
+
#### ✅ utils/gpu_manager.py (Memory Management)
|
| 65 |
+
- **Memory monitoring** - track allocated/free memory
|
| 66 |
+
- **Automatic cleanup** - garbage collection + CUDA cache clearing
|
| 67 |
+
- **Threshold alerts** - warn at 65GB/70GB limit
|
| 68 |
+
- **Optimization utilities**:
|
| 69 |
+
- FP16 conversion
|
| 70 |
+
- Memory-efficient attention detection
|
| 71 |
+
- Chunking recommendations
|
| 72 |
+
- **ZeroGPU duration calculator** - optimal `@spaces.GPU` parameters
|
| 73 |
+
|
| 74 |
+
#### ✅ requirements.txt
|
| 75 |
+
**Carefully ordered to avoid build errors:**
|
| 76 |
+
1. PyTorch (CUDA 12.1)
|
| 77 |
+
2. Flash Attention
|
| 78 |
+
3. Core ML libraries (xformers, transformers, diffusers)
|
| 79 |
+
4. Gradio + Spaces
|
| 80 |
+
5. Video/Image processing
|
| 81 |
+
6. Audio processing
|
| 82 |
+
7. Utilities
|
| 83 |
+
|
| 84 |
+
#### ✅ packages.txt
|
| 85 |
+
System dependencies:
|
| 86 |
+
- ffmpeg (video encoding)
|
| 87 |
+
- build-essential (compilation)
|
| 88 |
+
- libsndfile1 (audio)
|
| 89 |
+
- git (repo access)
|
| 90 |
+
|
| 91 |
+
### 3. Documentation Created
|
| 92 |
+
|
| 93 |
+
#### ✅ TODO.md
|
| 94 |
+
- **Critical integration steps** needed
|
| 95 |
+
- **Reference files** to study
|
| 96 |
+
- **Testing checklist**
|
| 97 |
+
- **Known issues** and solutions
|
| 98 |
+
- **Future enhancements** list
|
| 99 |
+
|
| 100 |
+
#### ✅ DEPLOYMENT.md
|
| 101 |
+
- **3 deployment methods** (Web UI, Git, CLI)
|
| 102 |
+
- **Troubleshooting guide** for common issues
|
| 103 |
+
- **Hardware options** comparison
|
| 104 |
+
- **Performance expectations**
|
| 105 |
+
- **Success checklist**
|
| 106 |
+
|
| 107 |
+
## ⚠️ What Still Needs to Be Done
|
| 108 |
+
|
| 109 |
+
### 🔴 Critical: Inference Integration
|
| 110 |
+
|
| 111 |
+
The current `app.py` has a **PLACEHOLDER** for video generation. You need to:
|
| 112 |
+
|
| 113 |
+
1. **Study the reference implementation** in cloned repo:
|
| 114 |
+
- `generate_infinitetalk.py` - main inference logic
|
| 115 |
+
- `wan/multitalk.py` - model forward pass
|
| 116 |
+
- `wan/utils/multitalk_utils.py` - utility functions
|
| 117 |
+
|
| 118 |
+
2. **Update `utils/model_loader.py`**:
|
| 119 |
+
- Replace placeholder in `load_wan_model()`
|
| 120 |
+
- Implement actual Wan model initialization
|
| 121 |
+
- Match InfiniteTalk's model loading pattern
|
| 122 |
+
|
| 123 |
+
3. **Complete `app.py` inference**:
|
| 124 |
+
- Around line 230, replace the `raise gr.Error()` placeholder
|
| 125 |
+
- Implement:
|
| 126 |
+
- Frame preprocessing
|
| 127 |
+
- Audio feature extraction (already started)
|
| 128 |
+
- Diffusion model inference
|
| 129 |
+
- Video assembly and encoding
|
| 130 |
+
- FFmpeg video+audio merging
|
| 131 |
+
|
| 132 |
+
4. **Test thoroughly**:
|
| 133 |
+
- Image-to-video generation
|
| 134 |
+
- Video dubbing
|
| 135 |
+
- Memory management
|
| 136 |
+
- Error handling
|
| 137 |
+
|
| 138 |
+
### Key Integration Points
|
| 139 |
+
|
| 140 |
+
```python
|
| 141 |
+
# In app.py, line ~230 - Replace this:
|
| 142 |
+
raise gr.Error("Video generation logic needs to be integrated...")
|
| 143 |
+
|
| 144 |
+
# With actual InfiniteTalk inference:
|
| 145 |
+
with torch.no_grad():
|
| 146 |
+
# 1. Prepare inputs
|
| 147 |
+
# 2. Run diffusion model
|
| 148 |
+
# 3. Generate frames
|
| 149 |
+
# 4. Assemble video
|
| 150 |
+
# 5. Merge audio
|
| 151 |
+
pass
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
## 📊 Current Status
|
| 155 |
+
|
| 156 |
+
| Component | Status | Notes |
|
| 157 |
+
|-----------|--------|-------|
|
| 158 |
+
| Project Structure | ✅ Complete | All directories and files created |
|
| 159 |
+
| Dependencies | ✅ Complete | requirements.txt & packages.txt ready |
|
| 160 |
+
| Model Loading | ⚠️ Template | Framework ready, needs actual implementation |
|
| 161 |
+
| GPU Management | ✅ Complete | Full monitoring and optimization |
|
| 162 |
+
| Gradio UI | ✅ Complete | Dual-tab interface with all controls |
|
| 163 |
+
| ZeroGPU Integration | ✅ Complete | Decorator and duration calculation |
|
| 164 |
+
| Inference Logic | 🔴 Incomplete | **CRITICAL: Placeholder only** |
|
| 165 |
+
| Documentation | ✅ Complete | README, TODO, DEPLOYMENT guides |
|
| 166 |
+
| Examples | ✅ Complete | Copied from original repo |
|
| 167 |
+
|
| 168 |
+
## 🚀 Next Steps
|
| 169 |
+
|
| 170 |
+
### Immediate (Required for Deployment)
|
| 171 |
+
|
| 172 |
+
1. **Complete inference integration** (see TODO.md)
|
| 173 |
+
2. **Test locally** if possible, or deploy for testing
|
| 174 |
+
3. **Debug any build errors** (especially flash-attn)
|
| 175 |
+
|
| 176 |
+
### Before Public Launch
|
| 177 |
+
|
| 178 |
+
1. **Verify model downloads** work correctly
|
| 179 |
+
2. **Test image-to-video** with multiple examples
|
| 180 |
+
3. **Test video dubbing** with multiple examples
|
| 181 |
+
4. **Confirm memory stays** under 65GB
|
| 182 |
+
5. **Ensure cleanup** works between generations
|
| 183 |
+
|
| 184 |
+
### Optional Enhancements
|
| 185 |
+
|
| 186 |
+
1. Add Text-to-Speech support (kokoro)
|
| 187 |
+
2. Add multi-person mode
|
| 188 |
+
3. Add video preview
|
| 189 |
+
4. Add progress bar for chunked processing
|
| 190 |
+
5. Add example presets
|
| 191 |
+
6. Add result gallery
|
| 192 |
+
|
| 193 |
+
## 📈 Expected Performance
|
| 194 |
+
|
| 195 |
+
### With Free ZeroGPU:
|
| 196 |
+
- **First run**: 2-3 minutes (model download)
|
| 197 |
+
- **480p generation**: ~40 seconds per 10s video
|
| 198 |
+
- **720p generation**: ~70 seconds per 10s video
|
| 199 |
+
- **Quota**: ~3-5 generations per period
|
| 200 |
+
|
| 201 |
+
### With PRO ZeroGPU ($9/month):
|
| 202 |
+
- **8× quota**: ~24-40 generations per period
|
| 203 |
+
- **Priority queue**: Faster starts
|
| 204 |
+
- **Multiple Spaces**: Up to 10 concurrent
|
| 205 |
+
|
| 206 |
+
## 🎯 Success Criteria
|
| 207 |
+
|
| 208 |
+
The Space is ready when:
|
| 209 |
+
|
| 210 |
+
- [x] All files are created and organized
|
| 211 |
+
- [x] Dependencies are properly ordered
|
| 212 |
+
- [x] ZeroGPU is configured
|
| 213 |
+
- [x] Gradio interface is functional
|
| 214 |
+
- [ ] **Inference generates actual videos** ⬅️ CRITICAL
|
| 215 |
+
- [ ] Models download automatically
|
| 216 |
+
- [ ] No OOM errors on 480p
|
| 217 |
+
- [ ] Memory cleanup works
|
| 218 |
+
- [ ] Multiple generations succeed
|
| 219 |
+
|
| 220 |
+
## 📚 Key Files to Reference
|
| 221 |
+
|
| 222 |
+
For completing the inference integration:
|
| 223 |
+
|
| 224 |
+
1. **Cloned repo's `generate_infinitetalk.py`** (main inference)
|
| 225 |
+
2. **Cloned repo's `app.py`** (original Gradio implementation)
|
| 226 |
+
3. **`wan/multitalk.py`** (model class)
|
| 227 |
+
4. **`wan/configs/*.py`** (configuration)
|
| 228 |
+
5. **`src/audio_analysis/wav2vec2.py`** (audio encoder)
|
| 229 |
+
|
| 230 |
+
## 💡 Tips
|
| 231 |
+
|
| 232 |
+
- **Start with image-to-video** - simpler than video dubbing
|
| 233 |
+
- **Test with short audio** (<10s) initially
|
| 234 |
+
- **Use 480p resolution** for faster iteration
|
| 235 |
+
- **Monitor logs** closely for errors
|
| 236 |
+
- **Check GPU memory** after each generation
|
| 237 |
+
- **Keep ZeroGPU duration** reasonable (<300s for free tier)
|
| 238 |
+
|
| 239 |
+
## 📞 Support Resources
|
| 240 |
+
|
| 241 |
+
- **InfiniteTalk GitHub**: https://github.com/MeiGen-AI/InfiniteTalk
|
| 242 |
+
- **HF Spaces Docs**: https://huggingface.co/docs/hub/spaces
|
| 243 |
+
- **ZeroGPU Docs**: https://huggingface.co/docs/hub/spaces-zerogpu
|
| 244 |
+
- **Gradio Docs**: https://gradio.app/docs
|
| 245 |
+
- **HF Forums**: https://discuss.huggingface.co
|
| 246 |
+
|
| 247 |
+
## 🎬 Ready to Deploy!
|
| 248 |
+
|
| 249 |
+
Once you complete the inference integration:
|
| 250 |
+
|
| 251 |
+
1. Review [DEPLOYMENT.md](./DEPLOYMENT.md)
|
| 252 |
+
2. Choose deployment method (Web UI recommended)
|
| 253 |
+
3. Upload all files to your HuggingFace Space
|
| 254 |
+
4. Wait for build (~5-10 minutes)
|
| 255 |
+
5. Test with examples
|
| 256 |
+
6. Share with the world! 🌟
|
| 257 |
+
|
| 258 |
+
---
|
| 259 |
+
|
| 260 |
+
**Note**: The framework is 90% complete. The main task remaining is integrating the actual InfiniteTalk inference logic from the original repository into the placeholder sections.
|
QUICK_START.md
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Quick Start Guide
|
| 2 |
+
|
| 3 |
+
## 🚀 Deploy in 5 Minutes
|
| 4 |
+
|
| 5 |
+
### Step 1: Complete the Inference (REQUIRED)
|
| 6 |
+
⚠️ **The code has placeholders for actual video generation**
|
| 7 |
+
|
| 8 |
+
See [TODO.md](./TODO.md) for details on integrating the inference logic.
|
| 9 |
+
|
| 10 |
+
### Step 2: Create HuggingFace Space
|
| 11 |
+
|
| 12 |
+
1. Go to https://huggingface.co/new-space
|
| 13 |
+
2. Fill in:
|
| 14 |
+
- **Name**: `infinitetalk` (or your choice)
|
| 15 |
+
- **License**: `apache-2.0`
|
| 16 |
+
- **SDK**: `Gradio`
|
| 17 |
+
- **Hardware**: `ZeroGPU` ✨ (FREE tier available!)
|
| 18 |
+
3. Click **Create Space**
|
| 19 |
+
|
| 20 |
+
### Step 3: Upload Files
|
| 21 |
+
|
| 22 |
+
**Via Web UI** (easiest):
|
| 23 |
+
1. Click "Files" tab in your Space
|
| 24 |
+
2. Drag and drop all files from this directory:
|
| 25 |
+
```
|
| 26 |
+
README.md
|
| 27 |
+
app.py
|
| 28 |
+
requirements.txt
|
| 29 |
+
packages.txt
|
| 30 |
+
.gitignore
|
| 31 |
+
LICENSE.txt
|
| 32 |
+
src/ (folder)
|
| 33 |
+
wan/ (folder)
|
| 34 |
+
utils/ (folder)
|
| 35 |
+
assets/ (folder)
|
| 36 |
+
examples/ (folder)
|
| 37 |
+
```
|
| 38 |
+
3. Click "Commit changes"
|
| 39 |
+
|
| 40 |
+
**Via Git**:
|
| 41 |
+
```bash
|
| 42 |
+
git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
|
| 43 |
+
cd YOUR_SPACE_NAME
|
| 44 |
+
cp -r /path/to/infinitetalk-hf-space/* .
|
| 45 |
+
git add .
|
| 46 |
+
git commit -m "Initial deployment"
|
| 47 |
+
git push
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### Step 4: Wait for Build
|
| 51 |
+
|
| 52 |
+
- Build time: **5-10 minutes**
|
| 53 |
+
- Check "Logs" tab for progress
|
| 54 |
+
- Flash-attn compilation takes longest
|
| 55 |
+
|
| 56 |
+
### Step 5: Test
|
| 57 |
+
|
| 58 |
+
1. Space shows "Running" ✅
|
| 59 |
+
2. First generation downloads models (2-3 min)
|
| 60 |
+
3. Try image-to-video example
|
| 61 |
+
4. Try video dubbing example
|
| 62 |
+
|
| 63 |
+
## ⚡ Quick Commands
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
# View directory structure
|
| 67 |
+
ls -la
|
| 68 |
+
|
| 69 |
+
# Check file sizes
|
| 70 |
+
du -sh *
|
| 71 |
+
|
| 72 |
+
# Count lines of code
|
| 73 |
+
find . -name "*.py" | xargs wc -l
|
| 74 |
+
|
| 75 |
+
# Test Python syntax
|
| 76 |
+
python -m py_compile app.py
|
| 77 |
+
|
| 78 |
+
# View logs (after deployment)
|
| 79 |
+
# Go to your Space → Logs tab
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
## 🎯 Common Issues & Fixes
|
| 83 |
+
|
| 84 |
+
### Build Fails
|
| 85 |
+
- **Check Logs tab** for specific error
|
| 86 |
+
- **Flash-attn timeout?** Normal, wait 10-15 min
|
| 87 |
+
- **Still failing?** Try Dockerfile approach (see DEPLOYMENT.md)
|
| 88 |
+
|
| 89 |
+
### Models Don't Download
|
| 90 |
+
- Check https://status.huggingface.co
|
| 91 |
+
- Verify model repo IDs in `utils/model_loader.py`
|
| 92 |
+
- Add HF_TOKEN in Space settings if needed
|
| 93 |
+
|
| 94 |
+
### Out of Memory
|
| 95 |
+
- Use 480p instead of 720p
|
| 96 |
+
- Reduce steps to 30
|
| 97 |
+
- Process shorter videos (<10s)
|
| 98 |
+
|
| 99 |
+
### Space Stuck
|
| 100 |
+
- Refresh page
|
| 101 |
+
- Check if in queue (ZeroGPU)
|
| 102 |
+
- Wait for quota to refill
|
| 103 |
+
|
| 104 |
+
## 📊 Files Overview
|
| 105 |
+
|
| 106 |
+
| File/Folder | Purpose | Lines | Critical? |
|
| 107 |
+
|-------------|---------|-------|-----------|
|
| 108 |
+
| `README.md` | Space metadata | ~50 | ✅ Yes |
|
| 109 |
+
| `app.py` | Main application | ~350 | ✅ Yes |
|
| 110 |
+
| `requirements.txt` | Python packages | ~30 | ✅ Yes |
|
| 111 |
+
| `packages.txt` | System packages | ~4 | ✅ Yes |
|
| 112 |
+
| `utils/model_loader.py` | Model management | ~200 | ✅ Yes |
|
| 113 |
+
| `utils/gpu_manager.py` | Memory management | ~150 | ✅ Yes |
|
| 114 |
+
| `src/` | Audio analysis | - | ✅ Yes |
|
| 115 |
+
| `wan/` | Model code | - | ✅ Yes |
|
| 116 |
+
| `assets/` | UI assets | - | Optional |
|
| 117 |
+
| `examples/` | Sample data | - | Optional |
|
| 118 |
+
|
| 119 |
+
## 🔧 Pre-Deployment Checklist
|
| 120 |
+
|
| 121 |
+
- [x] All files present
|
| 122 |
+
- [x] README.md has YAML metadata
|
| 123 |
+
- [x] requirements.txt is properly ordered
|
| 124 |
+
- [x] ZeroGPU hardware configured
|
| 125 |
+
- [ ] **Inference logic integrated** ⬅️ CRITICAL
|
| 126 |
+
- [ ] Tested locally (if possible)
|
| 127 |
+
- [ ] Examples prepared
|
| 128 |
+
|
| 129 |
+
## 💰 Cost Breakdown
|
| 130 |
+
|
| 131 |
+
### Free Tier
|
| 132 |
+
- **Cost**: $0
|
| 133 |
+
- **GPU**: H200 (70GB VRAM)
|
| 134 |
+
- **Quota**: 300s per session, 600s max
|
| 135 |
+
- **Usage**: ~3-5 generations per quota
|
| 136 |
+
- **Best for**: Testing, demos, light use
|
| 137 |
+
|
| 138 |
+
### PRO Tier
|
| 139 |
+
- **Cost**: $9/month
|
| 140 |
+
- **GPU**: Same H200
|
| 141 |
+
- **Quota**: 8× more (1500s)
|
| 142 |
+
- **Spaces**: Up to 10
|
| 143 |
+
- **Best for**: Regular use, public demos
|
| 144 |
+
|
| 145 |
+
## 📈 Performance Expectations
|
| 146 |
+
|
| 147 |
+
| Task | Resolution | Time | VRAM |
|
| 148 |
+
|------|-----------|------|------|
|
| 149 |
+
| Model download | - | 2-3 min | - |
|
| 150 |
+
| 10s video | 480p | ~40s | ~38GB |
|
| 151 |
+
| 10s video | 720p | ~70s | ~55GB |
|
| 152 |
+
| 30s video | 480p | ~90s | ~45GB |
|
| 153 |
+
|
| 154 |
+
## 🎓 Learning Resources
|
| 155 |
+
|
| 156 |
+
- [HuggingFace Spaces Tutorial](https://huggingface.co/docs/hub/spaces-overview)
|
| 157 |
+
- [Gradio Documentation](https://gradio.app/docs)
|
| 158 |
+
- [ZeroGPU Guide](https://huggingface.co/docs/hub/spaces-zerogpu)
|
| 159 |
+
- [InfiniteTalk Paper](https://arxiv.org/abs/2508.14033)
|
| 160 |
+
|
| 161 |
+
## ✅ Success Checklist
|
| 162 |
+
|
| 163 |
+
After deployment:
|
| 164 |
+
|
| 165 |
+
1. [ ] Space builds successfully
|
| 166 |
+
2. [ ] No errors in Logs
|
| 167 |
+
3. [ ] UI loads properly
|
| 168 |
+
4. [ ] Models download on first run
|
| 169 |
+
5. [ ] Image-to-video works
|
| 170 |
+
6. [ ] Video dubbing works
|
| 171 |
+
7. [ ] No OOM errors
|
| 172 |
+
8. [ ] Memory cleanup works
|
| 173 |
+
9. [ ] Can run multiple generations
|
| 174 |
+
10. [ ] Results look good!
|
| 175 |
+
|
| 176 |
+
## 🆘 Need Help?
|
| 177 |
+
|
| 178 |
+
1. **Check** [TODO.md](./TODO.md) for implementation details
|
| 179 |
+
2. **Read** [DEPLOYMENT.md](./DEPLOYMENT.md) for troubleshooting
|
| 180 |
+
3. **Review** [PROJECT_SUMMARY.md](./PROJECT_SUMMARY.md) for overview
|
| 181 |
+
4. **Ask** on HuggingFace Forums: https://discuss.huggingface.co
|
| 182 |
+
5. **File issue** on InfiniteTalk GitHub: https://github.com/MeiGen-AI/InfiniteTalk
|
| 183 |
+
|
| 184 |
+
---
|
| 185 |
+
|
| 186 |
+
**Ready?** Complete the inference integration, then deploy! 🚀
|
README.md
CHANGED
|
@@ -1,12 +1,61 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: InfiniteTalk - Talking Video Generator
|
| 3 |
+
emoji: 🎬
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "5.0.0"
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
+
hardware: zero-gpu
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# InfiniteTalk - Talking Video Generator
|
| 15 |
+
|
| 16 |
+
Generate realistic talking head videos with accurate lip-sync from images or dub existing videos with new audio!
|
| 17 |
+
|
| 18 |
+
## Features
|
| 19 |
+
|
| 20 |
+
- **Image-to-Video**: Transform a static portrait image into a talking video using audio input
|
| 21 |
+
- **Video Dubbing**: Re-sync an existing video with new audio while maintaining natural head movements and expressions
|
| 22 |
+
- **High Quality**: 480p and 720p resolution support with advanced lip-sync technology
|
| 23 |
+
- **Unlimited Length**: Support for videos of any duration through chunked processing
|
| 24 |
+
|
| 25 |
+
## How It Works
|
| 26 |
+
|
| 27 |
+
InfiniteTalk uses the state-of-the-art Wan2.1 diffusion model combined with specialized audio conditioning to create photorealistic talking videos. The system synchronizes:
|
| 28 |
+
|
| 29 |
+
- Lip movements with audio
|
| 30 |
+
- Head pose and rotations
|
| 31 |
+
- Facial expressions
|
| 32 |
+
- Body posture
|
| 33 |
+
|
| 34 |
+
## Usage
|
| 35 |
+
|
| 36 |
+
### Image-to-Video
|
| 37 |
+
1. Upload a portrait image (clear face visibility recommended)
|
| 38 |
+
2. Upload an audio file or use the example
|
| 39 |
+
3. Adjust parameters if needed
|
| 40 |
+
4. Click Generate
|
| 41 |
+
|
| 42 |
+
### Video Dubbing
|
| 43 |
+
1. Upload a video with a visible face
|
| 44 |
+
2. Upload new audio to dub over it
|
| 45 |
+
3. Adjust parameters if needed
|
| 46 |
+
4. Click Generate
|
| 47 |
+
|
| 48 |
+
## Parameters
|
| 49 |
+
|
| 50 |
+
- **Resolution**: Choose between 480p (faster) or 720p (higher quality)
|
| 51 |
+
- **Diffusion Steps**: More steps = higher quality but slower (20-50 recommended)
|
| 52 |
+
- **Audio Guide Scale**: Controls audio influence on generation (2-4 recommended)
|
| 53 |
+
- **Seed**: For reproducible results
|
| 54 |
+
|
| 55 |
+
## Credits
|
| 56 |
+
|
| 57 |
+
Built on [InfiniteTalk](https://github.com/MeiGen-AI/InfiniteTalk) by MeiGen-AI.
|
| 58 |
+
|
| 59 |
+
## License
|
| 60 |
+
|
| 61 |
+
Apache 2.0 - See LICENSE.txt for details
|
TODO.md
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# InfiniteTalk Space - TODO for Completion
|
| 2 |
+
|
| 3 |
+
## Critical: Inference Integration Needed
|
| 4 |
+
|
| 5 |
+
The current `app.py` has a **placeholder** for the actual video generation logic. To complete the implementation, you need to integrate the actual InfiniteTalk inference code.
|
| 6 |
+
|
| 7 |
+
### Steps to Complete:
|
| 8 |
+
|
| 9 |
+
#### 1. Review Reference Implementation
|
| 10 |
+
Check `temp-infinitetalk/generate_infinitetalk.py` for the actual inference logic, particularly:
|
| 11 |
+
- How the Wan model is initialized
|
| 12 |
+
- How audio conditioning works
|
| 13 |
+
- How frames are generated
|
| 14 |
+
- How the final video is assembled
|
| 15 |
+
|
| 16 |
+
#### 2. Update `utils/model_loader.py`
|
| 17 |
+
The `load_wan_model()` method currently has a placeholder. Replace it with actual Wan model loading:
|
| 18 |
+
|
| 19 |
+
```python
|
| 20 |
+
def load_wan_model(self, size="infinitetalk-480", device="cuda"):
|
| 21 |
+
# Replace the placeholder with actual Wan model initialization
|
| 22 |
+
# Reference: temp-infinitetalk/generate_infinitetalk.py lines ~200-300
|
| 23 |
+
pass
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
#### 3. Integrate Inference in `app.py`
|
| 27 |
+
In the `generate_video()` function around line 170, replace the placeholder section with:
|
| 28 |
+
|
| 29 |
+
```python
|
| 30 |
+
# Current placeholder (line ~230):
|
| 31 |
+
raise gr.Error("Video generation logic needs to be integrated...")
|
| 32 |
+
|
| 33 |
+
# Replace with actual inference code from generate_infinitetalk.py
|
| 34 |
+
# Key steps:
|
| 35 |
+
# 1. Load/prepare input frames
|
| 36 |
+
# 2. Extract and process audio features
|
| 37 |
+
# 3. Run diffusion model with audio conditioning
|
| 38 |
+
# 4. Post-process and save video
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
#### 4. Audio Feature Extraction
|
| 42 |
+
Ensure the audio feature extraction matches InfiniteTalk's requirements:
|
| 43 |
+
- Check if Wav2Vec2 preprocessing is correct
|
| 44 |
+
- Verify audio normalization parameters
|
| 45 |
+
- Confirm sample rate (16kHz)
|
| 46 |
+
|
| 47 |
+
#### 5. Video Assembly
|
| 48 |
+
Implement the video assembly logic:
|
| 49 |
+
- Frame generation loop
|
| 50 |
+
- Streaming/chunking for long videos
|
| 51 |
+
- FFmpeg video encoding
|
| 52 |
+
- Audio merging
|
| 53 |
+
|
| 54 |
+
### Reference Files to Study:
|
| 55 |
+
|
| 56 |
+
1. **`temp-infinitetalk/generate_infinitetalk.py`** - Main inference logic
|
| 57 |
+
2. **`temp-infinitetalk/app.py`** - Original Gradio implementation
|
| 58 |
+
3. **`wan/multitalk.py`** - Model inference
|
| 59 |
+
4. **`wan/utils/multitalk_utils.py`** - Utility functions
|
| 60 |
+
|
| 61 |
+
### Testing Checklist:
|
| 62 |
+
|
| 63 |
+
- [ ] Models download correctly from HuggingFace Hub
|
| 64 |
+
- [ ] Image input is properly processed
|
| 65 |
+
- [ ] Video input is properly processed
|
| 66 |
+
- [ ] Audio features are extracted correctly
|
| 67 |
+
- [ ] Video generation completes without OOM errors
|
| 68 |
+
- [ ] Output video has correct lip-sync
|
| 69 |
+
- [ ] Memory is cleaned up after generation
|
| 70 |
+
- [ ] Multiple generations work in sequence
|
| 71 |
+
|
| 72 |
+
## Optional Enhancements (Future):
|
| 73 |
+
|
| 74 |
+
- [ ] Add Text-to-Speech (kokoro integration)
|
| 75 |
+
- [ ] Add multi-person mode support
|
| 76 |
+
- [ ] Add progress bar for long videos
|
| 77 |
+
- [ ] Add video preview before generation
|
| 78 |
+
- [ ] Add batch processing
|
| 79 |
+
- [ ] Add custom LoRA support
|
| 80 |
+
- [ ] Add video quality comparison slider
|
| 81 |
+
|
| 82 |
+
## Known Issues:
|
| 83 |
+
|
| 84 |
+
1. **Flash-attn compilation**: May fail on some systems
|
| 85 |
+
- Solution: Use pre-built wheels or Dockerfile
|
| 86 |
+
2. **Model download time**: First run takes 2-3 minutes
|
| 87 |
+
- Expected behavior with 15GB+ models
|
| 88 |
+
3. **ZeroGPU timeout**: Long videos may exceed quota
|
| 89 |
+
- Solution: Implement chunking or recommend shorter inputs
|
| 90 |
+
|
| 91 |
+
## Deployment Notes:
|
| 92 |
+
|
| 93 |
+
See `DEPLOYMENT.md` for step-by-step deployment instructions.
|
app.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
InfiniteTalk - Talking Video Generator
|
| 3 |
+
Gradio Space with ZeroGPU support
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import random
|
| 9 |
+
import logging
|
| 10 |
+
import warnings
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import gradio as gr
|
| 14 |
+
import torch
|
| 15 |
+
import numpy as np
|
| 16 |
+
import spaces
|
| 17 |
+
|
| 18 |
+
# Suppress warnings
|
| 19 |
+
warnings.filterwarnings('ignore')
|
| 20 |
+
|
| 21 |
+
# Setup logging
|
| 22 |
+
logging.basicConfig(level=logging.INFO)
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
# Add current directory to path
|
| 26 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 27 |
+
|
| 28 |
+
# Import utilities
|
| 29 |
+
from utils.model_loader import ModelManager
|
| 30 |
+
from utils.gpu_manager import gpu_manager
|
| 31 |
+
|
| 32 |
+
# Import InfiniteTalk modules
|
| 33 |
+
import wan
|
| 34 |
+
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
|
| 35 |
+
from wan.utils.utils import cache_image, cache_video, is_video
|
| 36 |
+
from wan.utils.multitalk_utils import save_video_ffmpeg
|
| 37 |
+
|
| 38 |
+
# Audio processing
|
| 39 |
+
import librosa
|
| 40 |
+
import soundfile as sf
|
| 41 |
+
import pyloudnorm as pyln
|
| 42 |
+
from transformers import Wav2Vec2FeatureExtractor
|
| 43 |
+
from src.audio_analysis.wav2vec2 import Wav2Vec2Model
|
| 44 |
+
|
| 45 |
+
# Image/Video processing
|
| 46 |
+
from PIL import Image
|
| 47 |
+
from einops import rearrange
|
| 48 |
+
|
| 49 |
+
# Global variables
|
| 50 |
+
model_manager = None
|
| 51 |
+
models_loaded = False
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def initialize_models(progress=gr.Progress()):
|
| 55 |
+
"""Initialize models on first use"""
|
| 56 |
+
global model_manager, models_loaded
|
| 57 |
+
|
| 58 |
+
if models_loaded:
|
| 59 |
+
return
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
progress(0.1, desc="Initializing model manager...")
|
| 63 |
+
model_manager = ModelManager()
|
| 64 |
+
|
| 65 |
+
progress(0.3, desc="Downloading models (first time only - may take 2-3 minutes)...")
|
| 66 |
+
|
| 67 |
+
# Download models (lazy loading - they'll be loaded on first inference)
|
| 68 |
+
model_manager.get_wan_model_path()
|
| 69 |
+
model_manager.get_infinitetalk_weights_path()
|
| 70 |
+
model_manager.get_wav2vec_model_path()
|
| 71 |
+
|
| 72 |
+
models_loaded = True
|
| 73 |
+
progress(1.0, desc="Models ready!")
|
| 74 |
+
logger.info("Models initialized successfully")
|
| 75 |
+
|
| 76 |
+
except Exception as e:
|
| 77 |
+
logger.error(f"Error initializing models: {e}")
|
| 78 |
+
raise gr.Error(f"Failed to initialize models: {str(e)}")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def process_audio(audio_path, target_sr=16000):
|
| 82 |
+
"""
|
| 83 |
+
Process audio file for InfiniteTalk
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
audio_path: Path to audio file
|
| 87 |
+
target_sr: Target sample rate
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
Processed audio array and sample rate
|
| 91 |
+
"""
|
| 92 |
+
try:
|
| 93 |
+
# Load audio
|
| 94 |
+
audio, sr = librosa.load(audio_path, sr=None)
|
| 95 |
+
|
| 96 |
+
# Resample if needed
|
| 97 |
+
if sr != target_sr:
|
| 98 |
+
audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
|
| 99 |
+
sr = target_sr
|
| 100 |
+
|
| 101 |
+
# Normalize loudness
|
| 102 |
+
meter = pyln.Meter(sr)
|
| 103 |
+
loudness = meter.integrated_loudness(audio)
|
| 104 |
+
audio = pyln.normalize.loudness(audio, loudness, -20.0)
|
| 105 |
+
|
| 106 |
+
# Ensure mono
|
| 107 |
+
if len(audio.shape) > 1:
|
| 108 |
+
audio = np.mean(audio, axis=1)
|
| 109 |
+
|
| 110 |
+
return audio, sr
|
| 111 |
+
|
| 112 |
+
except Exception as e:
|
| 113 |
+
logger.error(f"Error processing audio: {e}")
|
| 114 |
+
raise gr.Error(f"Audio processing failed: {str(e)}")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def validate_inputs(image_or_video, audio, resolution, steps):
|
| 118 |
+
"""Validate user inputs"""
|
| 119 |
+
errors = []
|
| 120 |
+
|
| 121 |
+
if image_or_video is None:
|
| 122 |
+
errors.append("Please upload an image or video")
|
| 123 |
+
|
| 124 |
+
if audio is None:
|
| 125 |
+
errors.append("Please upload an audio file")
|
| 126 |
+
|
| 127 |
+
if resolution not in ["480p", "720p"]:
|
| 128 |
+
errors.append("Invalid resolution selected")
|
| 129 |
+
|
| 130 |
+
if not (20 <= steps <= 50):
|
| 131 |
+
errors.append("Steps must be between 20 and 50")
|
| 132 |
+
|
| 133 |
+
if errors:
|
| 134 |
+
raise gr.Error(" | ".join(errors))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@spaces.GPU(duration=180)
|
| 138 |
+
def generate_video(
|
| 139 |
+
image_or_video,
|
| 140 |
+
audio_file,
|
| 141 |
+
resolution="480p",
|
| 142 |
+
steps=40,
|
| 143 |
+
audio_guide_scale=3.0,
|
| 144 |
+
seed=-1,
|
| 145 |
+
progress=gr.Progress()
|
| 146 |
+
):
|
| 147 |
+
"""
|
| 148 |
+
Generate talking video from image or dub existing video
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
image_or_video: Input image or video file
|
| 152 |
+
audio_file: Audio file for lip-sync
|
| 153 |
+
resolution: Output resolution (480p or 720p)
|
| 154 |
+
steps: Number of diffusion steps
|
| 155 |
+
audio_guide_scale: Audio conditioning strength
|
| 156 |
+
seed: Random seed for reproducibility
|
| 157 |
+
progress: Gradio progress tracker
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
Path to generated video
|
| 161 |
+
"""
|
| 162 |
+
try:
|
| 163 |
+
# Initialize models if needed
|
| 164 |
+
if not models_loaded:
|
| 165 |
+
initialize_models(progress)
|
| 166 |
+
|
| 167 |
+
# Validate inputs
|
| 168 |
+
validate_inputs(image_or_video, audio_file, resolution, steps)
|
| 169 |
+
|
| 170 |
+
# GPU memory check
|
| 171 |
+
gpu_manager.print_memory_usage("Initial - ")
|
| 172 |
+
|
| 173 |
+
progress(0.1, desc="Processing audio...")
|
| 174 |
+
|
| 175 |
+
# Process audio
|
| 176 |
+
audio, sr = process_audio(audio_file)
|
| 177 |
+
audio_duration = len(audio) / sr
|
| 178 |
+
logger.info(f"Audio duration: {audio_duration:.2f}s")
|
| 179 |
+
|
| 180 |
+
# Calculate ZeroGPU duration
|
| 181 |
+
zerogpu_duration = gpu_manager.calculate_duration_for_zerogpu(
|
| 182 |
+
audio_duration, resolution
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
progress(0.2, desc="Loading models...")
|
| 186 |
+
|
| 187 |
+
# Load models
|
| 188 |
+
size = f"infinitetalk-{resolution.replace('p', '')}"
|
| 189 |
+
|
| 190 |
+
# Load Wan model
|
| 191 |
+
wan_model = model_manager.load_wan_model(size=size, device="cuda")
|
| 192 |
+
|
| 193 |
+
# Load audio encoder
|
| 194 |
+
audio_encoder, feature_extractor = model_manager.load_audio_encoder(device="cuda")
|
| 195 |
+
|
| 196 |
+
gpu_manager.print_memory_usage("After model loading - ")
|
| 197 |
+
|
| 198 |
+
progress(0.3, desc="Processing input...")
|
| 199 |
+
|
| 200 |
+
# Determine if input is image or video
|
| 201 |
+
is_input_video = is_video(image_or_video)
|
| 202 |
+
|
| 203 |
+
if is_input_video:
|
| 204 |
+
logger.info("Processing video dubbing...")
|
| 205 |
+
input_frames = cache_video(image_or_video)
|
| 206 |
+
else:
|
| 207 |
+
logger.info("Processing image-to-video...")
|
| 208 |
+
input_image = Image.open(image_or_video).convert("RGB")
|
| 209 |
+
input_frames = [input_image]
|
| 210 |
+
|
| 211 |
+
progress(0.4, desc="Extracting audio features...")
|
| 212 |
+
|
| 213 |
+
# Extract audio features
|
| 214 |
+
audio_features = feature_extractor(
|
| 215 |
+
audio,
|
| 216 |
+
sampling_rate=sr,
|
| 217 |
+
return_tensors="pt"
|
| 218 |
+
).input_values
|
| 219 |
+
|
| 220 |
+
audio_features = audio_features.to("cuda")
|
| 221 |
+
|
| 222 |
+
with torch.no_grad():
|
| 223 |
+
audio_embeddings = audio_encoder(audio_features).last_hidden_state
|
| 224 |
+
|
| 225 |
+
gpu_manager.print_memory_usage("After audio processing - ")
|
| 226 |
+
|
| 227 |
+
progress(0.5, desc="Generating video (this may take a minute)...")
|
| 228 |
+
|
| 229 |
+
# Set random seed
|
| 230 |
+
if seed == -1:
|
| 231 |
+
seed = random.randint(0, 99999999)
|
| 232 |
+
|
| 233 |
+
torch.manual_seed(seed)
|
| 234 |
+
if torch.cuda.is_available():
|
| 235 |
+
torch.cuda.manual_seed(seed)
|
| 236 |
+
|
| 237 |
+
# Generate video
|
| 238 |
+
# This is a placeholder for the actual inference logic
|
| 239 |
+
# The actual implementation would call wan_model.generate() with proper parameters
|
| 240 |
+
|
| 241 |
+
output_path = f"/tmp/output_{seed}.mp4"
|
| 242 |
+
|
| 243 |
+
# Simplified inference call (replace with actual InfiniteTalk logic)
|
| 244 |
+
with torch.no_grad():
|
| 245 |
+
# Parameters
|
| 246 |
+
generation_args = {
|
| 247 |
+
"input_frames": input_frames,
|
| 248 |
+
"audio_embeddings": audio_embeddings,
|
| 249 |
+
"num_steps": steps,
|
| 250 |
+
"audio_guide_scale": audio_guide_scale,
|
| 251 |
+
"size": size,
|
| 252 |
+
"seed": seed,
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
# Call model inference (placeholder)
|
| 256 |
+
# output_frames = wan_model.generate(**generation_args)
|
| 257 |
+
|
| 258 |
+
# For now, just create a dummy output to test the pipeline
|
| 259 |
+
# In production, this would be replaced with actual video generation
|
| 260 |
+
logger.info(f"Generating {resolution} video with {steps} steps...")
|
| 261 |
+
|
| 262 |
+
# Placeholder: copy input as output for testing
|
| 263 |
+
import shutil
|
| 264 |
+
if is_input_video:
|
| 265 |
+
shutil.copy(image_or_video, output_path)
|
| 266 |
+
else:
|
| 267 |
+
# Create a short video from the image
|
| 268 |
+
# This is just for testing - replace with actual generation
|
| 269 |
+
logger.warning("Placeholder: actual video generation not implemented yet")
|
| 270 |
+
raise gr.Error(
|
| 271 |
+
"Video generation logic needs to be integrated. "
|
| 272 |
+
"This is a template - please integrate the actual InfiniteTalk "
|
| 273 |
+
"inference code from generate_infinitetalk.py"
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
progress(0.9, desc="Finalizing...")
|
| 277 |
+
|
| 278 |
+
# Cleanup
|
| 279 |
+
gpu_manager.cleanup()
|
| 280 |
+
|
| 281 |
+
progress(1.0, desc="Complete!")
|
| 282 |
+
|
| 283 |
+
logger.info(f"Video generated successfully: {output_path}")
|
| 284 |
+
return output_path
|
| 285 |
+
|
| 286 |
+
except Exception as e:
|
| 287 |
+
logger.error(f"Error generating video: {e}")
|
| 288 |
+
gpu_manager.cleanup()
|
| 289 |
+
raise gr.Error(f"Generation failed: {str(e)}")
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def create_interface():
|
| 293 |
+
"""Create Gradio interface"""
|
| 294 |
+
|
| 295 |
+
with gr.Blocks(title="InfiniteTalk - Talking Video Generator", theme=gr.themes.Soft()) as demo:
|
| 296 |
+
gr.Markdown("""
|
| 297 |
+
# 🎬 InfiniteTalk - Talking Video Generator
|
| 298 |
+
|
| 299 |
+
Generate realistic talking head videos with accurate lip-sync from images or dub existing videos with new audio!
|
| 300 |
+
|
| 301 |
+
**Note**: First generation may take 2-3 minutes while models download. Subsequent generations are much faster (~40s for 10s video).
|
| 302 |
+
""")
|
| 303 |
+
|
| 304 |
+
with gr.Tabs():
|
| 305 |
+
# Tab 1: Image-to-Video
|
| 306 |
+
with gr.Tab("📸 Image-to-Video"):
|
| 307 |
+
gr.Markdown("Transform a static portrait into a talking video")
|
| 308 |
+
|
| 309 |
+
with gr.Row():
|
| 310 |
+
with gr.Column():
|
| 311 |
+
image_input = gr.Image(
|
| 312 |
+
type="filepath",
|
| 313 |
+
label="Upload Portrait Image",
|
| 314 |
+
info="Clear face visibility recommended"
|
| 315 |
+
)
|
| 316 |
+
audio_input_i2v = gr.Audio(
|
| 317 |
+
type="filepath",
|
| 318 |
+
label="Upload Audio",
|
| 319 |
+
info="MP3, WAV, or FLAC"
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 323 |
+
resolution_i2v = gr.Radio(
|
| 324 |
+
choices=["480p", "720p"],
|
| 325 |
+
value="480p",
|
| 326 |
+
label="Resolution",
|
| 327 |
+
info="480p is faster, 720p is higher quality"
|
| 328 |
+
)
|
| 329 |
+
steps_i2v = gr.Slider(
|
| 330 |
+
minimum=20,
|
| 331 |
+
maximum=50,
|
| 332 |
+
value=40,
|
| 333 |
+
step=1,
|
| 334 |
+
label="Diffusion Steps",
|
| 335 |
+
info="More steps = higher quality but slower"
|
| 336 |
+
)
|
| 337 |
+
audio_scale_i2v = gr.Slider(
|
| 338 |
+
minimum=1.0,
|
| 339 |
+
maximum=5.0,
|
| 340 |
+
value=3.0,
|
| 341 |
+
step=0.5,
|
| 342 |
+
label="Audio Guide Scale",
|
| 343 |
+
info="Controls audio influence (2-4 recommended)"
|
| 344 |
+
)
|
| 345 |
+
seed_i2v = gr.Number(
|
| 346 |
+
value=-1,
|
| 347 |
+
label="Seed",
|
| 348 |
+
info="-1 for random"
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
generate_btn_i2v = gr.Button("🎬 Generate Video", variant="primary", size="lg")
|
| 352 |
+
|
| 353 |
+
with gr.Column():
|
| 354 |
+
output_video_i2v = gr.Video(label="Generated Video")
|
| 355 |
+
gr.Markdown("**💡 Tip**: Use high-quality portrait images with clear facial features for best results")
|
| 356 |
+
|
| 357 |
+
generate_btn_i2v.click(
|
| 358 |
+
fn=generate_video,
|
| 359 |
+
inputs=[image_input, audio_input_i2v, resolution_i2v, steps_i2v, audio_scale_i2v, seed_i2v],
|
| 360 |
+
outputs=output_video_i2v
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
# Tab 2: Video Dubbing
|
| 364 |
+
with gr.Tab("🎥 Video Dubbing"):
|
| 365 |
+
gr.Markdown("Dub an existing video with new audio while maintaining natural movements")
|
| 366 |
+
|
| 367 |
+
with gr.Row():
|
| 368 |
+
with gr.Column():
|
| 369 |
+
video_input = gr.Video(
|
| 370 |
+
label="Upload Video",
|
| 371 |
+
info="Video with visible face"
|
| 372 |
+
)
|
| 373 |
+
audio_input_v2v = gr.Audio(
|
| 374 |
+
type="filepath",
|
| 375 |
+
label="Upload New Audio",
|
| 376 |
+
info="MP3, WAV, or FLAC"
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 380 |
+
resolution_v2v = gr.Radio(
|
| 381 |
+
choices=["480p", "720p"],
|
| 382 |
+
value="480p",
|
| 383 |
+
label="Resolution"
|
| 384 |
+
)
|
| 385 |
+
steps_v2v = gr.Slider(
|
| 386 |
+
minimum=20,
|
| 387 |
+
maximum=50,
|
| 388 |
+
value=40,
|
| 389 |
+
step=1,
|
| 390 |
+
label="Diffusion Steps"
|
| 391 |
+
)
|
| 392 |
+
audio_scale_v2v = gr.Slider(
|
| 393 |
+
minimum=1.0,
|
| 394 |
+
maximum=5.0,
|
| 395 |
+
value=3.0,
|
| 396 |
+
step=0.5,
|
| 397 |
+
label="Audio Guide Scale"
|
| 398 |
+
)
|
| 399 |
+
seed_v2v = gr.Number(
|
| 400 |
+
value=-1,
|
| 401 |
+
label="Seed"
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
generate_btn_v2v = gr.Button("🎬 Generate Dubbed Video", variant="primary", size="lg")
|
| 405 |
+
|
| 406 |
+
with gr.Column():
|
| 407 |
+
output_video_v2v = gr.Video(label="Dubbed Video")
|
| 408 |
+
gr.Markdown("**💡 Tip**: For best results, use videos with consistent face visibility throughout")
|
| 409 |
+
|
| 410 |
+
generate_btn_v2v.click(
|
| 411 |
+
fn=generate_video,
|
| 412 |
+
inputs=[video_input, audio_input_v2v, resolution_v2v, steps_v2v, audio_scale_v2v, seed_v2v],
|
| 413 |
+
outputs=output_video_v2v
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
# Footer
|
| 417 |
+
gr.Markdown("""
|
| 418 |
+
---
|
| 419 |
+
### About
|
| 420 |
+
Powered by [InfiniteTalk](https://github.com/MeiGen-AI/InfiniteTalk) - Apache 2.0 License
|
| 421 |
+
|
| 422 |
+
**Free Tier Usage**: ~3-5 generations per quota period on free ZeroGPU
|
| 423 |
+
|
| 424 |
+
💡 **Tips**:
|
| 425 |
+
- First generation downloads models (~15GB) and may take 2-3 minutes
|
| 426 |
+
- Use 480p for faster generation (~40s for 10s video)
|
| 427 |
+
- Use 720p for higher quality (slower but better results)
|
| 428 |
+
- Clear, well-lit images produce the best results
|
| 429 |
+
""")
|
| 430 |
+
|
| 431 |
+
return demo
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
if __name__ == "__main__":
|
| 435 |
+
demo = create_interface()
|
| 436 |
+
demo.queue(max_size=10)
|
| 437 |
+
demo.launch()
|
assets/InfiniteTalk_paper.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dcefdbb788a7f10aa941adf642a8f511fbb99b874e8dd271b9067caefa6b41b2
|
| 3 |
+
size 13015738
|
assets/logo.jpg
ADDED
|
assets/logo2.jpg
ADDED
|
Git LFS Details
|
assets/pipeline.png
ADDED
|
Git LFS Details
|
examples/multi/1-man.WAV
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d304fd88850d6673649d1844db2894e03bf5a775123048eebcb01ab3b79bff5e
|
| 3 |
+
size 1503276
|
examples/multi/1-woman.WAV
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3e1ebd7ae1587ebc7f0986f8b61e7fcc99c6fb57fbb15ab9373968e701afc8bf
|
| 3 |
+
size 1503276
|
examples/multi/ref_img.png
ADDED
|
Git LFS Details
|
examples/multi_example_image.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"prompt": "In a casual, intimate setting, a man and a woman are engaged in a heartfelt conversation inside a car. The man, sporting a denim jacket over a blue shirt, sits attentively with a seatbelt fastened, his gaze fixed on the woman beside him. The woman, wearing a black tank top and a denim jacket draped over her shoulders, smiles warmly, her eyes reflecting genuine interest and connection. The car's interior, with its beige seats and simple design, provides a backdrop that emphasizes their interaction. The scene captures a moment of shared understanding and connection, set against the soft, diffused light of an overcast day. A medium shot from a slightly angled perspective, focusing on their expressions and body language.",
|
| 3 |
+
"cond_video": "examples/multi/ref_img.png",
|
| 4 |
+
"audio_type": "para",
|
| 5 |
+
"cond_audio": {
|
| 6 |
+
"person1": "examples/multi/1-man.WAV",
|
| 7 |
+
"person2": "examples/multi/1-woman.WAV"
|
| 8 |
+
}
|
| 9 |
+
}
|
examples/single/1.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba2733897f561f747e6508734bff4eeee29d0a73638e5c39c0c0b806701d4e8b
|
| 3 |
+
size 1888320
|
examples/single/ref_image.png
ADDED
|
Git LFS Details
|
examples/single/ref_video.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3cb07cbfa63576d8b06eb2954cc56d1b089764f0a9428da867348810d6cb9071
|
| 3 |
+
size 843790
|
examples/single_example_image.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"prompt": "A woman is passionately singing into a professional microphone in a recording studio. She wears large black headphones and a dark cardigan over a gray top. Her long, wavy brown hair frames her face as she looks slightly upwards, her mouth open mid-song. The studio is equipped with various audio equipment, including a mixing console and a keyboard, with soundproofing panels on the walls. The lighting is warm and focused on her, creating a professional and intimate atmosphere. A close-up shot captures her expressive performance.",
|
| 3 |
+
"cond_video": "examples/single/ref_image.png",
|
| 4 |
+
"cond_audio": {
|
| 5 |
+
"person1": "examples/single/1.wav"
|
| 6 |
+
}
|
| 7 |
+
}
|
examples/single_example_video.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"prompt": "A man is talking",
|
| 3 |
+
"cond_video": "examples/single/ref_video.mp4",
|
| 4 |
+
"cond_audio": {
|
| 5 |
+
"person1": "examples/single/1.wav"
|
| 6 |
+
}
|
| 7 |
+
}
|
packages.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ffmpeg
|
| 2 |
+
build-essential
|
| 3 |
+
libsndfile1
|
| 4 |
+
git
|
requirements.txt
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 1. PyTorch FIRST (CUDA 12.1 compatible with ZeroGPU)
|
| 2 |
+
torch==2.4.1
|
| 3 |
+
torchvision==0.19.1
|
| 4 |
+
torchaudio==2.4.1
|
| 5 |
+
|
| 6 |
+
# 2. Flash Attention (may need --no-build-isolation)
|
| 7 |
+
flash-attn==2.7.4.post1
|
| 8 |
+
|
| 9 |
+
# 3. Core ML libraries
|
| 10 |
+
xformers==0.0.28
|
| 11 |
+
transformers>=4.49.0
|
| 12 |
+
tokenizers>=0.20.3
|
| 13 |
+
diffusers>=0.31.0
|
| 14 |
+
accelerate>=1.1.1
|
| 15 |
+
einops
|
| 16 |
+
|
| 17 |
+
# 4. Gradio and Spaces
|
| 18 |
+
gradio>=5.0.0
|
| 19 |
+
spaces
|
| 20 |
+
|
| 21 |
+
# 5. Video/Image processing
|
| 22 |
+
opencv-python-headless>=4.9.0.80
|
| 23 |
+
moviepy==1.0.3
|
| 24 |
+
imageio
|
| 25 |
+
imageio-ffmpeg
|
| 26 |
+
scikit-image
|
| 27 |
+
decord
|
| 28 |
+
scenedetect
|
| 29 |
+
|
| 30 |
+
# 6. Audio processing
|
| 31 |
+
librosa
|
| 32 |
+
soundfile
|
| 33 |
+
pyloudnorm
|
| 34 |
+
|
| 35 |
+
# 7. Utilities
|
| 36 |
+
tqdm
|
| 37 |
+
numpy>=1.23.5,<2
|
| 38 |
+
easydict
|
| 39 |
+
ftfy
|
| 40 |
+
loguru
|
| 41 |
+
optimum-quanto==0.2.6
|
| 42 |
+
xfuser>=0.4.1
|
src/audio_analysis/torch_utils.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_mask_from_lengths(lengths, max_len=None):
|
| 6 |
+
lengths = lengths.to(torch.long)
|
| 7 |
+
if max_len is None:
|
| 8 |
+
max_len = torch.max(lengths).item()
|
| 9 |
+
|
| 10 |
+
ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
|
| 11 |
+
mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
|
| 12 |
+
|
| 13 |
+
return mask
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def linear_interpolation(features, seq_len):
|
| 17 |
+
features = features.transpose(1, 2)
|
| 18 |
+
output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
|
| 19 |
+
return output_features.transpose(1, 2)
|
| 20 |
+
|
src/audio_analysis/wav2vec2.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import Wav2Vec2Config, Wav2Vec2Model
|
| 2 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 3 |
+
|
| 4 |
+
from src.audio_analysis.torch_utils import linear_interpolation
|
| 5 |
+
|
| 6 |
+
# the implementation of Wav2Vec2Model is borrowed from
|
| 7 |
+
# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
|
| 8 |
+
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
|
| 9 |
+
class Wav2Vec2Model(Wav2Vec2Model):
|
| 10 |
+
def __init__(self, config: Wav2Vec2Config):
|
| 11 |
+
super().__init__(config)
|
| 12 |
+
|
| 13 |
+
def forward(
|
| 14 |
+
self,
|
| 15 |
+
input_values,
|
| 16 |
+
seq_len,
|
| 17 |
+
attention_mask=None,
|
| 18 |
+
mask_time_indices=None,
|
| 19 |
+
output_attentions=None,
|
| 20 |
+
output_hidden_states=None,
|
| 21 |
+
return_dict=None,
|
| 22 |
+
):
|
| 23 |
+
self.config.output_attentions = True
|
| 24 |
+
|
| 25 |
+
output_hidden_states = (
|
| 26 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 27 |
+
)
|
| 28 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 29 |
+
|
| 30 |
+
extract_features = self.feature_extractor(input_values)
|
| 31 |
+
extract_features = extract_features.transpose(1, 2)
|
| 32 |
+
extract_features = linear_interpolation(extract_features, seq_len=seq_len)
|
| 33 |
+
|
| 34 |
+
if attention_mask is not None:
|
| 35 |
+
# compute reduced attention_mask corresponding to feature vectors
|
| 36 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
| 37 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
hidden_states, extract_features = self.feature_projection(extract_features)
|
| 41 |
+
hidden_states = self._mask_hidden_states(
|
| 42 |
+
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
encoder_outputs = self.encoder(
|
| 46 |
+
hidden_states,
|
| 47 |
+
attention_mask=attention_mask,
|
| 48 |
+
output_attentions=output_attentions,
|
| 49 |
+
output_hidden_states=output_hidden_states,
|
| 50 |
+
return_dict=return_dict,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
hidden_states = encoder_outputs[0]
|
| 54 |
+
|
| 55 |
+
if self.adapter is not None:
|
| 56 |
+
hidden_states = self.adapter(hidden_states)
|
| 57 |
+
|
| 58 |
+
if not return_dict:
|
| 59 |
+
return (hidden_states, ) + encoder_outputs[1:]
|
| 60 |
+
return BaseModelOutput(
|
| 61 |
+
last_hidden_state=hidden_states,
|
| 62 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 63 |
+
attentions=encoder_outputs.attentions,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def feature_extract(
|
| 68 |
+
self,
|
| 69 |
+
input_values,
|
| 70 |
+
seq_len,
|
| 71 |
+
):
|
| 72 |
+
extract_features = self.feature_extractor(input_values)
|
| 73 |
+
extract_features = extract_features.transpose(1, 2)
|
| 74 |
+
extract_features = linear_interpolation(extract_features, seq_len=seq_len)
|
| 75 |
+
|
| 76 |
+
return extract_features
|
| 77 |
+
|
| 78 |
+
def encode(
|
| 79 |
+
self,
|
| 80 |
+
extract_features,
|
| 81 |
+
attention_mask=None,
|
| 82 |
+
mask_time_indices=None,
|
| 83 |
+
output_attentions=None,
|
| 84 |
+
output_hidden_states=None,
|
| 85 |
+
return_dict=None,
|
| 86 |
+
):
|
| 87 |
+
self.config.output_attentions = True
|
| 88 |
+
|
| 89 |
+
output_hidden_states = (
|
| 90 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 91 |
+
)
|
| 92 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 93 |
+
|
| 94 |
+
if attention_mask is not None:
|
| 95 |
+
# compute reduced attention_mask corresponding to feature vectors
|
| 96 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
| 97 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
hidden_states, extract_features = self.feature_projection(extract_features)
|
| 102 |
+
hidden_states = self._mask_hidden_states(
|
| 103 |
+
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
encoder_outputs = self.encoder(
|
| 107 |
+
hidden_states,
|
| 108 |
+
attention_mask=attention_mask,
|
| 109 |
+
output_attentions=output_attentions,
|
| 110 |
+
output_hidden_states=output_hidden_states,
|
| 111 |
+
return_dict=return_dict,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
hidden_states = encoder_outputs[0]
|
| 115 |
+
|
| 116 |
+
if self.adapter is not None:
|
| 117 |
+
hidden_states = self.adapter(hidden_states)
|
| 118 |
+
|
| 119 |
+
if not return_dict:
|
| 120 |
+
return (hidden_states, ) + encoder_outputs[1:]
|
| 121 |
+
return BaseModelOutput(
|
| 122 |
+
last_hidden_state=hidden_states,
|
| 123 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 124 |
+
attentions=encoder_outputs.attentions,
|
| 125 |
+
)
|
src/utils.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import contextmanager
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
@contextmanager
|
| 6 |
+
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
|
| 7 |
+
old_register_parameter = torch.nn.Module.register_parameter
|
| 8 |
+
if include_buffers:
|
| 9 |
+
old_register_buffer = torch.nn.Module.register_buffer
|
| 10 |
+
|
| 11 |
+
def register_empty_parameter(module, name, param):
|
| 12 |
+
old_register_parameter(module, name, param)
|
| 13 |
+
if param is not None:
|
| 14 |
+
param_cls = type(module._parameters[name])
|
| 15 |
+
kwargs = module._parameters[name].__dict__
|
| 16 |
+
kwargs["requires_grad"] = param.requires_grad
|
| 17 |
+
module._parameters[name] = param_cls(
|
| 18 |
+
module._parameters[name].to(device), **kwargs
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
def register_empty_buffer(module, name, buffer, persistent=True):
|
| 22 |
+
old_register_buffer(module, name, buffer, persistent=persistent)
|
| 23 |
+
if buffer is not None:
|
| 24 |
+
module._buffers[name] = module._buffers[name].to(device)
|
| 25 |
+
|
| 26 |
+
def patch_tensor_constructor(fn):
|
| 27 |
+
def wrapper(*args, **kwargs):
|
| 28 |
+
kwargs["device"] = device
|
| 29 |
+
return fn(*args, **kwargs)
|
| 30 |
+
|
| 31 |
+
return wrapper
|
| 32 |
+
|
| 33 |
+
if include_buffers:
|
| 34 |
+
tensor_constructors_to_patch = {
|
| 35 |
+
torch_function_name: getattr(torch, torch_function_name)
|
| 36 |
+
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
| 37 |
+
}
|
| 38 |
+
else:
|
| 39 |
+
tensor_constructors_to_patch = {}
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
torch.nn.Module.register_parameter = register_empty_parameter
|
| 43 |
+
if include_buffers:
|
| 44 |
+
torch.nn.Module.register_buffer = register_empty_buffer
|
| 45 |
+
for torch_function_name in tensor_constructors_to_patch.keys():
|
| 46 |
+
setattr(
|
| 47 |
+
torch,
|
| 48 |
+
torch_function_name,
|
| 49 |
+
patch_tensor_constructor(getattr(torch, torch_function_name)),
|
| 50 |
+
)
|
| 51 |
+
yield
|
| 52 |
+
finally:
|
| 53 |
+
torch.nn.Module.register_parameter = old_register_parameter
|
| 54 |
+
if include_buffers:
|
| 55 |
+
torch.nn.Module.register_buffer = old_register_buffer
|
| 56 |
+
for (
|
| 57 |
+
torch_function_name,
|
| 58 |
+
old_torch_function,
|
| 59 |
+
) in tensor_constructors_to_patch.items():
|
| 60 |
+
setattr(torch, torch_function_name, old_torch_function)
|
src/vram_management/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .layers import *
|
src/vram_management/layers.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from src.utils import init_weights_on_device
|
| 6 |
+
import optimum.quanto.nn.qlinear as qlinear
|
| 7 |
+
|
| 8 |
+
def cast_to(weight, dtype, device):
|
| 9 |
+
r = torch.empty_like(weight, dtype=dtype, device=device)
|
| 10 |
+
r.copy_(weight)
|
| 11 |
+
return r
|
| 12 |
+
|
| 13 |
+
def cast_to_device(weight, device):
|
| 14 |
+
if hasattr(weight, '__class__') and 'optimum.quanto' in str(weight.__class__):
|
| 15 |
+
return weight.to(device)
|
| 16 |
+
else:
|
| 17 |
+
r = torch.empty_like(weight, device=device)
|
| 18 |
+
r.copy_(weight)
|
| 19 |
+
return r
|
| 20 |
+
|
| 21 |
+
class AutoWrappedModule(torch.nn.Module):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
module: torch.nn.Module,
|
| 25 |
+
offload_dtype,
|
| 26 |
+
offload_device,
|
| 27 |
+
onload_dtype,
|
| 28 |
+
onload_device,
|
| 29 |
+
computation_dtype,
|
| 30 |
+
computation_device,
|
| 31 |
+
):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.module = module.to(dtype=offload_dtype, device=offload_device)
|
| 34 |
+
self.offload_dtype = offload_dtype
|
| 35 |
+
self.offload_device = offload_device
|
| 36 |
+
self.onload_dtype = onload_dtype
|
| 37 |
+
self.onload_device = onload_device
|
| 38 |
+
self.computation_dtype = computation_dtype
|
| 39 |
+
self.computation_device = computation_device
|
| 40 |
+
self.state = 0
|
| 41 |
+
|
| 42 |
+
def offload(self):
|
| 43 |
+
if self.state == 1 and (
|
| 44 |
+
self.offload_dtype != self.onload_dtype
|
| 45 |
+
or self.offload_device != self.onload_device
|
| 46 |
+
):
|
| 47 |
+
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
|
| 48 |
+
self.state = 0
|
| 49 |
+
|
| 50 |
+
def onload(self):
|
| 51 |
+
if self.state == 0 and (
|
| 52 |
+
self.offload_dtype != self.onload_dtype
|
| 53 |
+
or self.offload_device != self.onload_device
|
| 54 |
+
):
|
| 55 |
+
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
|
| 56 |
+
self.state = 1
|
| 57 |
+
|
| 58 |
+
def forward(self, *args, **kwargs):
|
| 59 |
+
if (
|
| 60 |
+
self.onload_dtype == self.computation_dtype
|
| 61 |
+
and self.onload_device == self.computation_device
|
| 62 |
+
):
|
| 63 |
+
module = self.module
|
| 64 |
+
else:
|
| 65 |
+
module = copy.deepcopy(self.module).to(
|
| 66 |
+
dtype=self.computation_dtype, device=self.computation_device
|
| 67 |
+
)
|
| 68 |
+
return module(*args, **kwargs)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class AutoWrappedQLinear(qlinear.QLinear):
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
module: qlinear.QLinear,
|
| 76 |
+
offload_dtype,
|
| 77 |
+
offload_device,
|
| 78 |
+
onload_dtype,
|
| 79 |
+
onload_device,
|
| 80 |
+
computation_dtype,
|
| 81 |
+
computation_device,
|
| 82 |
+
):
|
| 83 |
+
with init_weights_on_device(device=torch.device("meta")):
|
| 84 |
+
super().__init__(
|
| 85 |
+
in_features=module.in_features,
|
| 86 |
+
out_features=module.out_features,
|
| 87 |
+
bias=module.bias is not None,
|
| 88 |
+
device=offload_device,
|
| 89 |
+
)
|
| 90 |
+
self.weight = module.weight
|
| 91 |
+
self.bias = module.bias
|
| 92 |
+
self.offload_device = offload_device
|
| 93 |
+
|
| 94 |
+
self.onload_device = onload_device
|
| 95 |
+
self.computation_device = computation_device
|
| 96 |
+
self.state = 0
|
| 97 |
+
|
| 98 |
+
def offload(self):
|
| 99 |
+
if self.state == 1 and (
|
| 100 |
+
self.offload_device != self.onload_device
|
| 101 |
+
):
|
| 102 |
+
self.to(device=self.offload_device)
|
| 103 |
+
self.state = 0
|
| 104 |
+
|
| 105 |
+
def onload(self):
|
| 106 |
+
if self.state == 0 and (
|
| 107 |
+
self.offload_device != self.onload_device
|
| 108 |
+
):
|
| 109 |
+
self.to(device=self.onload_device)
|
| 110 |
+
self.state = 1
|
| 111 |
+
|
| 112 |
+
def forward(self, x, *args, **kwargs):
|
| 113 |
+
if (
|
| 114 |
+
self.onload_device == self.computation_device
|
| 115 |
+
):
|
| 116 |
+
|
| 117 |
+
return torch.nn.functional.linear(x, self.weight, bias=self.bias)
|
| 118 |
+
else:
|
| 119 |
+
|
| 120 |
+
qweight = cast_to_device(self.weight, self.computation_device)
|
| 121 |
+
bias = (
|
| 122 |
+
None
|
| 123 |
+
if self.bias is None
|
| 124 |
+
else cast_to_device(self.bias, self.computation_device)
|
| 125 |
+
)
|
| 126 |
+
return torch.nn.functional.linear(x, qweight, bias)
|
| 127 |
+
|
| 128 |
+
class AutoWrappedLinear(torch.nn.Linear):
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
module: torch.nn.Linear,
|
| 132 |
+
offload_dtype,
|
| 133 |
+
offload_device,
|
| 134 |
+
onload_dtype,
|
| 135 |
+
onload_device,
|
| 136 |
+
computation_dtype,
|
| 137 |
+
computation_device,
|
| 138 |
+
):
|
| 139 |
+
with init_weights_on_device(device=torch.device("meta")):
|
| 140 |
+
super().__init__(
|
| 141 |
+
in_features=module.in_features,
|
| 142 |
+
out_features=module.out_features,
|
| 143 |
+
bias=module.bias is not None,
|
| 144 |
+
dtype=offload_dtype,
|
| 145 |
+
device=offload_device,
|
| 146 |
+
)
|
| 147 |
+
self.weight = module.weight
|
| 148 |
+
self.bias = module.bias
|
| 149 |
+
self.offload_dtype = offload_dtype
|
| 150 |
+
self.offload_device = offload_device
|
| 151 |
+
self.onload_dtype = onload_dtype
|
| 152 |
+
self.onload_device = onload_device
|
| 153 |
+
self.computation_dtype = computation_dtype
|
| 154 |
+
self.computation_device = computation_device
|
| 155 |
+
self.state = 0
|
| 156 |
+
|
| 157 |
+
def offload(self):
|
| 158 |
+
if self.state == 1 and (
|
| 159 |
+
self.offload_dtype != self.onload_dtype
|
| 160 |
+
or self.offload_device != self.onload_device
|
| 161 |
+
):
|
| 162 |
+
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
| 163 |
+
self.state = 0
|
| 164 |
+
|
| 165 |
+
def onload(self):
|
| 166 |
+
if self.state == 0 and (
|
| 167 |
+
self.offload_dtype != self.onload_dtype
|
| 168 |
+
or self.offload_device != self.onload_device
|
| 169 |
+
):
|
| 170 |
+
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
| 171 |
+
self.state = 1
|
| 172 |
+
|
| 173 |
+
def forward(self, x, *args, **kwargs):
|
| 174 |
+
if (
|
| 175 |
+
self.onload_dtype == self.computation_dtype
|
| 176 |
+
and self.onload_device == self.computation_device
|
| 177 |
+
):
|
| 178 |
+
weight, bias = self.weight, self.bias
|
| 179 |
+
else:
|
| 180 |
+
weight = cast_to(
|
| 181 |
+
self.weight, self.computation_dtype, self.computation_device
|
| 182 |
+
)
|
| 183 |
+
bias = (
|
| 184 |
+
None
|
| 185 |
+
if self.bias is None
|
| 186 |
+
else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
| 187 |
+
)
|
| 188 |
+
return torch.nn.functional.linear(x, weight, bias)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def enable_vram_management_recursively(
|
| 192 |
+
model: torch.nn.Module,
|
| 193 |
+
module_map: dict,
|
| 194 |
+
module_config: dict,
|
| 195 |
+
max_num_param=None,
|
| 196 |
+
overflow_module_config: dict = None,
|
| 197 |
+
total_num_param=0,
|
| 198 |
+
):
|
| 199 |
+
for name, module in model.named_children():
|
| 200 |
+
for source_module, target_module in module_map.items():
|
| 201 |
+
if isinstance(module, source_module):
|
| 202 |
+
num_param = sum(p.numel() for p in module.parameters())
|
| 203 |
+
# print(str(module) + ':' + str(num_param))
|
| 204 |
+
if (
|
| 205 |
+
max_num_param is not None
|
| 206 |
+
and total_num_param + num_param > max_num_param
|
| 207 |
+
):
|
| 208 |
+
# print(str(module) + '-->\t\t num:' + str(num_param) + "\t total:" + str(total_num_param))
|
| 209 |
+
module_config_ = overflow_module_config
|
| 210 |
+
else:
|
| 211 |
+
module_config_ = module_config
|
| 212 |
+
module_ = target_module(module, **module_config_)
|
| 213 |
+
setattr(model, name, module_)
|
| 214 |
+
total_num_param += num_param
|
| 215 |
+
break
|
| 216 |
+
else:
|
| 217 |
+
total_num_param = enable_vram_management_recursively(
|
| 218 |
+
module,
|
| 219 |
+
module_map,
|
| 220 |
+
module_config,
|
| 221 |
+
max_num_param,
|
| 222 |
+
overflow_module_config,
|
| 223 |
+
total_num_param,
|
| 224 |
+
)
|
| 225 |
+
return total_num_param
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def enable_vram_management(
|
| 229 |
+
model: torch.nn.Module,
|
| 230 |
+
module_map: dict,
|
| 231 |
+
module_config: dict,
|
| 232 |
+
max_num_param=None,
|
| 233 |
+
overflow_module_config: dict = None,
|
| 234 |
+
):
|
| 235 |
+
enable_vram_management_recursively(
|
| 236 |
+
model,
|
| 237 |
+
module_map,
|
| 238 |
+
module_config,
|
| 239 |
+
max_num_param,
|
| 240 |
+
overflow_module_config,
|
| 241 |
+
total_num_param=0,
|
| 242 |
+
)
|
| 243 |
+
model.vram_management_enabled = True
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility modules for InfiniteTalk Space"""
|
| 2 |
+
|
| 3 |
+
from .model_loader import ModelManager
|
| 4 |
+
from .gpu_manager import GPUManager, gpu_manager
|
| 5 |
+
|
| 6 |
+
__all__ = ["ModelManager", "GPUManager", "gpu_manager"]
|
utils/gpu_manager.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPU Memory Manager for InfiniteTalk
|
| 3 |
+
Handles memory monitoring, cleanup, and optimization
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
logging.basicConfig(level=logging.INFO)
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class GPUManager:
|
| 15 |
+
"""Manages GPU memory usage and optimization"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, max_memory_gb=65):
|
| 18 |
+
"""
|
| 19 |
+
Initialize GPU Manager
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
max_memory_gb: Maximum memory threshold in GB (default 65GB for 70GB H200)
|
| 23 |
+
"""
|
| 24 |
+
self.max_memory_bytes = max_memory_gb * 1024 ** 3
|
| 25 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 26 |
+
|
| 27 |
+
def get_memory_usage(self):
|
| 28 |
+
"""
|
| 29 |
+
Get current GPU memory usage
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
dict with allocated, reserved, and free memory in GB
|
| 33 |
+
"""
|
| 34 |
+
if not torch.cuda.is_available():
|
| 35 |
+
return {"allocated": 0, "reserved": 0, "free": 0}
|
| 36 |
+
|
| 37 |
+
allocated = torch.cuda.memory_allocated() / 1024 ** 3
|
| 38 |
+
reserved = torch.cuda.memory_reserved() / 1024 ** 3
|
| 39 |
+
total = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
|
| 40 |
+
free = total - allocated
|
| 41 |
+
|
| 42 |
+
return {
|
| 43 |
+
"allocated": round(allocated, 2),
|
| 44 |
+
"reserved": round(reserved, 2),
|
| 45 |
+
"free": round(free, 2),
|
| 46 |
+
"total": round(total, 2)
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
def print_memory_usage(self, prefix=""):
|
| 50 |
+
"""Print current memory usage"""
|
| 51 |
+
usage = self.get_memory_usage()
|
| 52 |
+
logger.info(
|
| 53 |
+
f"{prefix}GPU Memory - "
|
| 54 |
+
f"Allocated: {usage['allocated']}GB, "
|
| 55 |
+
f"Reserved: {usage['reserved']}GB, "
|
| 56 |
+
f"Free: {usage['free']}GB"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def check_memory_threshold(self):
|
| 60 |
+
"""
|
| 61 |
+
Check if memory usage exceeds threshold
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
bool: True if within safe limits, False if exceeded
|
| 65 |
+
"""
|
| 66 |
+
if not torch.cuda.is_available():
|
| 67 |
+
return True
|
| 68 |
+
|
| 69 |
+
allocated = torch.cuda.memory_allocated()
|
| 70 |
+
|
| 71 |
+
if allocated > self.max_memory_bytes:
|
| 72 |
+
logger.warning(
|
| 73 |
+
f"Memory threshold exceeded! "
|
| 74 |
+
f"Allocated: {allocated / 1024**3:.2f}GB, "
|
| 75 |
+
f"Threshold: {self.max_memory_bytes / 1024**3:.2f}GB"
|
| 76 |
+
)
|
| 77 |
+
return False
|
| 78 |
+
|
| 79 |
+
return True
|
| 80 |
+
|
| 81 |
+
def cleanup(self):
|
| 82 |
+
"""Perform garbage collection and CUDA cache cleanup"""
|
| 83 |
+
import gc
|
| 84 |
+
|
| 85 |
+
gc.collect()
|
| 86 |
+
if torch.cuda.is_available():
|
| 87 |
+
torch.cuda.empty_cache()
|
| 88 |
+
torch.cuda.synchronize()
|
| 89 |
+
|
| 90 |
+
logger.info("GPU memory cleaned up")
|
| 91 |
+
self.print_memory_usage("After cleanup - ")
|
| 92 |
+
|
| 93 |
+
def optimize_model_for_inference(self, model):
|
| 94 |
+
"""
|
| 95 |
+
Apply optimizations to model for inference
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
model: PyTorch model to optimize
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
Optimized model
|
| 102 |
+
"""
|
| 103 |
+
model.eval()
|
| 104 |
+
|
| 105 |
+
# Enable gradient checkpointing if available
|
| 106 |
+
if hasattr(model, "enable_gradient_checkpointing"):
|
| 107 |
+
model.enable_gradient_checkpointing()
|
| 108 |
+
|
| 109 |
+
# Use FP16 for inference to save memory
|
| 110 |
+
if torch.cuda.is_available() and hasattr(model, "half"):
|
| 111 |
+
logger.info("Converting model to FP16")
|
| 112 |
+
model = model.half()
|
| 113 |
+
|
| 114 |
+
return model
|
| 115 |
+
|
| 116 |
+
def enable_memory_efficient_attention(self):
|
| 117 |
+
"""Enable memory-efficient attention mechanisms"""
|
| 118 |
+
try:
|
| 119 |
+
import xformers
|
| 120 |
+
|
| 121 |
+
logger.info("xformers available - memory efficient attention enabled")
|
| 122 |
+
return True
|
| 123 |
+
except ImportError:
|
| 124 |
+
logger.warning("xformers not available - using standard attention")
|
| 125 |
+
return False
|
| 126 |
+
|
| 127 |
+
def estimate_inference_memory(self, resolution="480p", duration_seconds=10):
|
| 128 |
+
"""
|
| 129 |
+
Estimate memory requirements for inference
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
resolution: Video resolution (480p or 720p)
|
| 133 |
+
duration_seconds: Video duration in seconds
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
Estimated memory in GB
|
| 137 |
+
"""
|
| 138 |
+
base_memory = 20 # Base model memory
|
| 139 |
+
|
| 140 |
+
if resolution == "720p":
|
| 141 |
+
per_second_memory = 1.5
|
| 142 |
+
else: # 480p
|
| 143 |
+
per_second_memory = 0.8
|
| 144 |
+
|
| 145 |
+
estimated = base_memory + (duration_seconds * per_second_memory)
|
| 146 |
+
|
| 147 |
+
logger.info(
|
| 148 |
+
f"Estimated memory for {resolution} video ({duration_seconds}s): "
|
| 149 |
+
f"{estimated:.2f}GB"
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
return estimated
|
| 153 |
+
|
| 154 |
+
def should_use_chunking(self, video_duration, resolution="480p"):
|
| 155 |
+
"""
|
| 156 |
+
Determine if chunked processing should be used
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
video_duration: Duration in seconds
|
| 160 |
+
resolution: Video resolution
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
bool: True if chunking recommended
|
| 164 |
+
"""
|
| 165 |
+
estimated_memory = self.estimate_inference_memory(resolution, video_duration)
|
| 166 |
+
|
| 167 |
+
# Use chunking if estimated memory exceeds 50GB
|
| 168 |
+
return estimated_memory > 50
|
| 169 |
+
|
| 170 |
+
def get_optimal_chunk_size(self, resolution="480p"):
|
| 171 |
+
"""
|
| 172 |
+
Get optimal chunk size for video processing
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
resolution: Video resolution
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
Optimal chunk size in seconds
|
| 179 |
+
"""
|
| 180 |
+
if resolution == "720p":
|
| 181 |
+
return 10 # 10 second chunks for 720p
|
| 182 |
+
else:
|
| 183 |
+
return 15 # 15 second chunks for 480p
|
| 184 |
+
|
| 185 |
+
@staticmethod
|
| 186 |
+
def calculate_duration_for_zerogpu(video_duration, resolution="480p"):
|
| 187 |
+
"""
|
| 188 |
+
Calculate ZeroGPU duration parameter
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
video_duration: Duration of video in seconds
|
| 192 |
+
resolution: Video resolution
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
Recommended duration for @spaces.GPU decorator
|
| 196 |
+
"""
|
| 197 |
+
base_time = 60 # Base time for model loading
|
| 198 |
+
|
| 199 |
+
# Processing time per second of video
|
| 200 |
+
if resolution == "720p":
|
| 201 |
+
processing_rate = 3.5
|
| 202 |
+
else: # 480p
|
| 203 |
+
processing_rate = 2.5
|
| 204 |
+
|
| 205 |
+
# Add safety margin of 1.2x
|
| 206 |
+
estimated_time = base_time + (video_duration * processing_rate)
|
| 207 |
+
duration = int(estimated_time * 1.2)
|
| 208 |
+
|
| 209 |
+
# Cap at 300 seconds for free tier (300s ZeroGPU = 10 min real time)
|
| 210 |
+
duration = min(duration, 300)
|
| 211 |
+
|
| 212 |
+
logger.info(
|
| 213 |
+
f"Calculated ZeroGPU duration: {duration}s for "
|
| 214 |
+
f"{video_duration}s {resolution} video"
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
return duration
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Global instance
|
| 221 |
+
gpu_manager = GPUManager()
|
utils/model_loader.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Manager for InfiniteTalk
|
| 3 |
+
Handles lazy loading and caching of models from HuggingFace Hub
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
from huggingface_hub import snapshot_download
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
logging.basicConfig(level=logging.INFO)
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ModelManager:
|
| 17 |
+
"""Manages model loading and caching"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, cache_dir=None):
|
| 20 |
+
"""
|
| 21 |
+
Initialize Model Manager
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
cache_dir: Directory for caching models. Defaults to HF_HOME or /data/.huggingface
|
| 25 |
+
"""
|
| 26 |
+
if cache_dir is None:
|
| 27 |
+
cache_dir = os.environ.get("HF_HOME", "/data/.huggingface")
|
| 28 |
+
|
| 29 |
+
self.cache_dir = Path(cache_dir)
|
| 30 |
+
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
| 31 |
+
|
| 32 |
+
self.models = {}
|
| 33 |
+
self.model_paths = {
|
| 34 |
+
"wan": None,
|
| 35 |
+
"infinitetalk": None,
|
| 36 |
+
"wav2vec": None
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
def download_model(self, repo_id, subfolder=None, filename=None):
|
| 40 |
+
"""
|
| 41 |
+
Download model from HuggingFace Hub with caching
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
repo_id: HuggingFace repository ID (e.g., "Kijai/WanVideo_comfy")
|
| 45 |
+
subfolder: Optional subfolder within the repository
|
| 46 |
+
filename: Optional specific file to download
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Path to downloaded model directory
|
| 50 |
+
"""
|
| 51 |
+
try:
|
| 52 |
+
logger.info(f"Downloading {repo_id} from HuggingFace Hub...")
|
| 53 |
+
|
| 54 |
+
download_kwargs = {
|
| 55 |
+
"repo_id": repo_id,
|
| 56 |
+
"cache_dir": str(self.cache_dir),
|
| 57 |
+
"resume_download": True,
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
if subfolder:
|
| 61 |
+
download_kwargs["allow_patterns"] = f"{subfolder}/*"
|
| 62 |
+
if filename:
|
| 63 |
+
download_kwargs["allow_patterns"] = filename
|
| 64 |
+
|
| 65 |
+
model_path = snapshot_download(**download_kwargs)
|
| 66 |
+
|
| 67 |
+
if subfolder:
|
| 68 |
+
model_path = os.path.join(model_path, subfolder)
|
| 69 |
+
|
| 70 |
+
logger.info(f"Model downloaded successfully to {model_path}")
|
| 71 |
+
return model_path
|
| 72 |
+
|
| 73 |
+
except Exception as e:
|
| 74 |
+
logger.error(f"Error downloading model {repo_id}: {e}")
|
| 75 |
+
raise
|
| 76 |
+
|
| 77 |
+
def get_wan_model_path(self):
|
| 78 |
+
"""Get or download Wan2.1 I2V model"""
|
| 79 |
+
if self.model_paths["wan"] is None:
|
| 80 |
+
logger.info("Downloading Wan2.1-I2V-14B-480P model...")
|
| 81 |
+
# This will download the full model - adjust repo_id based on actual HF location
|
| 82 |
+
self.model_paths["wan"] = self.download_model(
|
| 83 |
+
repo_id="Kijai/WanVideo_comfy",
|
| 84 |
+
subfolder="wan2_1_i2v_14B_480P"
|
| 85 |
+
)
|
| 86 |
+
return self.model_paths["wan"]
|
| 87 |
+
|
| 88 |
+
def get_infinitetalk_weights_path(self):
|
| 89 |
+
"""Get or download InfiniteTalk weights"""
|
| 90 |
+
if self.model_paths["infinitetalk"] is None:
|
| 91 |
+
logger.info("Downloading InfiniteTalk weights...")
|
| 92 |
+
self.model_paths["infinitetalk"] = self.download_model(
|
| 93 |
+
repo_id="MeiGen-AI/InfiniteTalk",
|
| 94 |
+
subfolder="single"
|
| 95 |
+
)
|
| 96 |
+
return self.model_paths["infinitetalk"]
|
| 97 |
+
|
| 98 |
+
def get_wav2vec_model_path(self):
|
| 99 |
+
"""Get or download Wav2Vec2 audio encoder"""
|
| 100 |
+
if self.model_paths["wav2vec"] is None:
|
| 101 |
+
logger.info("Downloading Wav2Vec2 audio encoder...")
|
| 102 |
+
self.model_paths["wav2vec"] = self.download_model(
|
| 103 |
+
repo_id="TencentGameMate/chinese-wav2vec2-base"
|
| 104 |
+
)
|
| 105 |
+
return self.model_paths["wav2vec"]
|
| 106 |
+
|
| 107 |
+
def load_wan_model(self, size="infinitetalk-480", device="cuda"):
|
| 108 |
+
"""
|
| 109 |
+
Load Wan model for inference
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
size: Model size configuration
|
| 113 |
+
device: Device to load model on
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
Loaded model
|
| 117 |
+
"""
|
| 118 |
+
if "wan_model" not in self.models:
|
| 119 |
+
import wan
|
| 120 |
+
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
|
| 121 |
+
|
| 122 |
+
model_path = self.get_wan_model_path()
|
| 123 |
+
infinitetalk_path = self.get_infinitetalk_weights_path()
|
| 124 |
+
|
| 125 |
+
logger.info(f"Loading Wan model from {model_path}...")
|
| 126 |
+
|
| 127 |
+
# Initialize model based on InfiniteTalk's approach
|
| 128 |
+
task = "infinitetalk-14B"
|
| 129 |
+
args_dict = {
|
| 130 |
+
"ckpt_dir": model_path,
|
| 131 |
+
"infinitetalk_dir": os.path.join(infinitetalk_path, "infinitetalk.safetensors"),
|
| 132 |
+
"task": task,
|
| 133 |
+
"size": size,
|
| 134 |
+
"sample_steps": 40,
|
| 135 |
+
"sample_shift": 7 if size == "infinitetalk-480" else 11,
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
# Create a simple namespace object for args
|
| 139 |
+
class Args:
|
| 140 |
+
def __init__(self, **kwargs):
|
| 141 |
+
self.__dict__.update(kwargs)
|
| 142 |
+
|
| 143 |
+
args = Args(**args_dict)
|
| 144 |
+
|
| 145 |
+
# Load model (simplified - actual loading would use wan.load_model())
|
| 146 |
+
# This is a placeholder - actual implementation would call the wan library
|
| 147 |
+
model = wan.WanModel(args)
|
| 148 |
+
model.to(device)
|
| 149 |
+
model.eval()
|
| 150 |
+
|
| 151 |
+
self.models["wan_model"] = model
|
| 152 |
+
logger.info("Wan model loaded successfully")
|
| 153 |
+
|
| 154 |
+
return self.models["wan_model"]
|
| 155 |
+
|
| 156 |
+
def load_audio_encoder(self, device="cuda"):
|
| 157 |
+
"""
|
| 158 |
+
Load Wav2Vec2 audio encoder
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
device: Device to load model on
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
Audio encoder model and feature extractor
|
| 165 |
+
"""
|
| 166 |
+
if "audio_encoder" not in self.models:
|
| 167 |
+
from transformers import Wav2Vec2FeatureExtractor
|
| 168 |
+
from src.audio_analysis.wav2vec2 import Wav2Vec2Model
|
| 169 |
+
|
| 170 |
+
wav2vec_path = self.get_wav2vec_model_path()
|
| 171 |
+
|
| 172 |
+
logger.info(f"Loading audio encoder from {wav2vec_path}...")
|
| 173 |
+
|
| 174 |
+
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_path)
|
| 175 |
+
audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec_path)
|
| 176 |
+
audio_encoder.to(device)
|
| 177 |
+
audio_encoder.eval()
|
| 178 |
+
|
| 179 |
+
self.models["audio_encoder"] = (audio_encoder, feature_extractor)
|
| 180 |
+
logger.info("Audio encoder loaded successfully")
|
| 181 |
+
|
| 182 |
+
return self.models["audio_encoder"]
|
| 183 |
+
|
| 184 |
+
def unload_model(self, model_name):
|
| 185 |
+
"""Unload a specific model to free memory"""
|
| 186 |
+
if model_name in self.models:
|
| 187 |
+
del self.models[model_name]
|
| 188 |
+
torch.cuda.empty_cache()
|
| 189 |
+
logger.info(f"Unloaded {model_name}")
|
| 190 |
+
|
| 191 |
+
def clear_all(self):
|
| 192 |
+
"""Unload all models"""
|
| 193 |
+
self.models.clear()
|
| 194 |
+
torch.cuda.empty_cache()
|
| 195 |
+
logger.info("All models unloaded")
|
wan/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import configs, distributed, modules
|
| 2 |
+
from .first_last_frame2video import WanFLF2V
|
| 3 |
+
from .image2video import WanI2V
|
| 4 |
+
from .text2video import WanT2V
|
| 5 |
+
from .vace import WanVace, WanVaceMP
|
| 6 |
+
from .multitalk import InfiniteTalkPipeline
|
wan/configs/__init__.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import copy
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
| 6 |
+
|
| 7 |
+
from .wan_i2v_14B import i2v_14B
|
| 8 |
+
from .wan_t2v_1_3B import t2v_1_3B
|
| 9 |
+
from .wan_t2v_14B import t2v_14B
|
| 10 |
+
from .wan_multitalk_14B import multitalk_14B
|
| 11 |
+
|
| 12 |
+
# the config of t2i_14B is the same as t2v_14B
|
| 13 |
+
t2i_14B = copy.deepcopy(t2v_14B)
|
| 14 |
+
t2i_14B.__name__ = 'Config: Wan T2I 14B'
|
| 15 |
+
|
| 16 |
+
# the config of flf2v_14B is the same as i2v_14B
|
| 17 |
+
flf2v_14B = copy.deepcopy(i2v_14B)
|
| 18 |
+
flf2v_14B.__name__ = 'Config: Wan FLF2V 14B'
|
| 19 |
+
flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt
|
| 20 |
+
|
| 21 |
+
WAN_CONFIGS = {
|
| 22 |
+
't2v-14B': t2v_14B,
|
| 23 |
+
't2v-1.3B': t2v_1_3B,
|
| 24 |
+
'i2v-14B': i2v_14B,
|
| 25 |
+
't2i-14B': t2i_14B,
|
| 26 |
+
'flf2v-14B': flf2v_14B,
|
| 27 |
+
'vace-1.3B': t2v_1_3B,
|
| 28 |
+
'vace-14B': t2v_14B,
|
| 29 |
+
'infinitetalk-14B': multitalk_14B,
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
SIZE_CONFIGS = {
|
| 33 |
+
'720*1280': (720, 1280),
|
| 34 |
+
'1280*720': (1280, 720),
|
| 35 |
+
'480*832': (480, 832),
|
| 36 |
+
'832*480': (832, 480),
|
| 37 |
+
'1024*1024': (1024, 1024),
|
| 38 |
+
'infinitetalk-480': (640, 640),
|
| 39 |
+
'infinitetalk-720': (960, 960),
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
MAX_AREA_CONFIGS = {
|
| 43 |
+
'720*1280': 720 * 1280,
|
| 44 |
+
'1280*720': 1280 * 720,
|
| 45 |
+
'480*832': 480 * 832,
|
| 46 |
+
'832*480': 832 * 480,
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
SUPPORTED_SIZES = {
|
| 50 |
+
't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
| 51 |
+
't2v-1.3B': ('480*832', '832*480'),
|
| 52 |
+
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
| 53 |
+
'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
| 54 |
+
't2i-14B': tuple(SIZE_CONFIGS.keys()),
|
| 55 |
+
'vace-1.3B': ('480*832', '832*480'),
|
| 56 |
+
'vace-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
| 57 |
+
'infinitetalk-14B': ('infinitetalk-480', 'infinitetalk-720'),
|
| 58 |
+
}
|
wan/configs/shared_config.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
from easydict import EasyDict
|
| 4 |
+
|
| 5 |
+
#------------------------ Wan shared config ------------------------#
|
| 6 |
+
wan_shared_cfg = EasyDict()
|
| 7 |
+
|
| 8 |
+
# t5
|
| 9 |
+
wan_shared_cfg.t5_model = 'umt5_xxl'
|
| 10 |
+
wan_shared_cfg.t5_dtype = torch.bfloat16
|
| 11 |
+
wan_shared_cfg.text_len = 512
|
| 12 |
+
|
| 13 |
+
# transformer
|
| 14 |
+
wan_shared_cfg.param_dtype = torch.bfloat16
|
| 15 |
+
|
| 16 |
+
# inference
|
| 17 |
+
wan_shared_cfg.num_train_timesteps = 1000
|
| 18 |
+
wan_shared_cfg.sample_fps = 16
|
| 19 |
+
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
|
wan/configs/wan_i2v_14B.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
from easydict import EasyDict
|
| 4 |
+
|
| 5 |
+
from .shared_config import wan_shared_cfg
|
| 6 |
+
|
| 7 |
+
#------------------------ Wan I2V 14B ------------------------#
|
| 8 |
+
|
| 9 |
+
i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
|
| 10 |
+
i2v_14B.update(wan_shared_cfg)
|
| 11 |
+
i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt
|
| 12 |
+
|
| 13 |
+
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
|
| 14 |
+
i2v_14B.t5_tokenizer = 'google/umt5-xxl'
|
| 15 |
+
|
| 16 |
+
# clip
|
| 17 |
+
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
|
| 18 |
+
i2v_14B.clip_dtype = torch.float16
|
| 19 |
+
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
|
| 20 |
+
i2v_14B.clip_tokenizer = 'xlm-roberta-large'
|
| 21 |
+
|
| 22 |
+
# vae
|
| 23 |
+
i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
|
| 24 |
+
i2v_14B.vae_stride = (4, 8, 8)
|
wan/configs/wan_multitalk_14B.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
from easydict import EasyDict
|
| 4 |
+
|
| 5 |
+
from .shared_config import wan_shared_cfg
|
| 6 |
+
|
| 7 |
+
#------------------------ Wan I2V 14B ------------------------#
|
| 8 |
+
|
| 9 |
+
multitalk_14B = EasyDict(__name__='Config: Wan MultiTalk AI2V 14B')
|
| 10 |
+
multitalk_14B.update(wan_shared_cfg)
|
| 11 |
+
multitalk_14B.sample_neg_prompt = 'bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards'
|
| 12 |
+
|
| 13 |
+
multitalk_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
|
| 14 |
+
multitalk_14B.t5_tokenizer = 'google/umt5-xxl'
|
| 15 |
+
|
| 16 |
+
# clip
|
| 17 |
+
multitalk_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
|
| 18 |
+
multitalk_14B.clip_dtype = torch.float16
|
| 19 |
+
multitalk_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
|
| 20 |
+
multitalk_14B.clip_tokenizer = 'xlm-roberta-large'
|
| 21 |
+
|
| 22 |
+
# vae
|
| 23 |
+
multitalk_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
|
| 24 |
+
multitalk_14B.vae_stride = (4, 8, 8)
|
| 25 |
+
|
| 26 |
+
# transformer
|
| 27 |
+
multitalk_14B.patch_size = (1, 2, 2)
|
| 28 |
+
multitalk_14B.dim = 5120
|
| 29 |
+
multitalk_14B.ffn_dim = 13824
|
| 30 |
+
multitalk_14B.freq_dim = 256
|
| 31 |
+
multitalk_14B.num_heads = 40
|
| 32 |
+
multitalk_14B.num_layers = 40
|
| 33 |
+
multitalk_14B.window_size = (-1, -1)
|
| 34 |
+
multitalk_14B.qk_norm = True
|
| 35 |
+
multitalk_14B.cross_attn_norm = True
|
| 36 |
+
multitalk_14B.eps = 1e-6
|
wan/configs/wan_t2v_14B.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
from easydict import EasyDict
|
| 3 |
+
|
| 4 |
+
from .shared_config import wan_shared_cfg
|
| 5 |
+
|
| 6 |
+
#------------------------ Wan T2V 14B ------------------------#
|
| 7 |
+
|
| 8 |
+
t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
|
| 9 |
+
t2v_14B.update(wan_shared_cfg)
|
| 10 |
+
|
| 11 |
+
# t5
|
| 12 |
+
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
|
| 13 |
+
t2v_14B.t5_tokenizer = 'google/umt5-xxl'
|
| 14 |
+
|
| 15 |
+
# vae
|
| 16 |
+
t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
|
| 17 |
+
t2v_14B.vae_stride = (4, 8, 8)
|
| 18 |
+
|
| 19 |
+
# transformer
|
| 20 |
+
t2v_14B.patch_size = (1, 2, 2)
|
| 21 |
+
t2v_14B.dim = 5120
|
| 22 |
+
t2v_14B.ffn_dim = 13824
|
| 23 |
+
t2v_14B.freq_dim = 256
|
| 24 |
+
t2v_14B.num_heads = 40
|
| 25 |
+
t2v_14B.num_layers = 40
|
| 26 |
+
t2v_14B.window_size = (-1, -1)
|
| 27 |
+
t2v_14B.qk_norm = True
|
| 28 |
+
t2v_14B.cross_attn_norm = True
|
| 29 |
+
t2v_14B.eps = 1e-6
|
wan/configs/wan_t2v_1_3B.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
from easydict import EasyDict
|
| 3 |
+
|
| 4 |
+
from .shared_config import wan_shared_cfg
|
| 5 |
+
|
| 6 |
+
#------------------------ Wan T2V 1.3B ------------------------#
|
| 7 |
+
|
| 8 |
+
t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
|
| 9 |
+
t2v_1_3B.update(wan_shared_cfg)
|
| 10 |
+
|
| 11 |
+
# t5
|
| 12 |
+
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
|
| 13 |
+
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
|
| 14 |
+
|
| 15 |
+
# vae
|
| 16 |
+
t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
|
| 17 |
+
t2v_1_3B.vae_stride = (4, 8, 8)
|
| 18 |
+
|
| 19 |
+
# transformer
|
| 20 |
+
t2v_1_3B.patch_size = (1, 2, 2)
|
| 21 |
+
t2v_1_3B.dim = 1536
|
| 22 |
+
t2v_1_3B.ffn_dim = 8960
|
| 23 |
+
t2v_1_3B.freq_dim = 256
|
| 24 |
+
t2v_1_3B.num_heads = 12
|
| 25 |
+
t2v_1_3B.num_layers = 30
|
| 26 |
+
t2v_1_3B.window_size = (-1, -1)
|
| 27 |
+
t2v_1_3B.qk_norm = True
|
| 28 |
+
t2v_1_3B.cross_attn_norm = True
|
| 29 |
+
t2v_1_3B.eps = 1e-6
|
wan/distributed/__init__.py
ADDED
|
File without changes
|
wan/distributed/fsdp.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import gc
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 7 |
+
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
| 8 |
+
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
| 9 |
+
from torch.distributed.utils import _free_storage
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def shard_model(
|
| 13 |
+
model,
|
| 14 |
+
device_id,
|
| 15 |
+
param_dtype=torch.bfloat16,
|
| 16 |
+
reduce_dtype=torch.float32,
|
| 17 |
+
buffer_dtype=torch.float32,
|
| 18 |
+
process_group=None,
|
| 19 |
+
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
| 20 |
+
sync_module_states=True,
|
| 21 |
+
):
|
| 22 |
+
model = FSDP(
|
| 23 |
+
module=model,
|
| 24 |
+
process_group=process_group,
|
| 25 |
+
sharding_strategy=sharding_strategy,
|
| 26 |
+
auto_wrap_policy=partial(
|
| 27 |
+
lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
|
| 28 |
+
# mixed_precision=MixedPrecision(
|
| 29 |
+
# param_dtype=param_dtype,
|
| 30 |
+
# reduce_dtype=reduce_dtype,
|
| 31 |
+
# buffer_dtype=buffer_dtype),
|
| 32 |
+
device_id=device_id,
|
| 33 |
+
sync_module_states=sync_module_states)
|
| 34 |
+
return model
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def free_model(model):
|
| 38 |
+
for m in model.modules():
|
| 39 |
+
if isinstance(m, FSDP):
|
| 40 |
+
_free_storage(m._handle.flat_param.data)
|
| 41 |
+
del model
|
| 42 |
+
gc.collect()
|
| 43 |
+
torch.cuda.empty_cache()
|
wan/distributed/xdit_context_parallel.py
ADDED
|
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.cuda.amp as amp
|
| 6 |
+
from xfuser.core.distributed import (
|
| 7 |
+
get_sequence_parallel_rank,
|
| 8 |
+
get_sequence_parallel_world_size,
|
| 9 |
+
get_sp_group,
|
| 10 |
+
)
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
| 13 |
+
import xformers.ops
|
| 14 |
+
|
| 15 |
+
from ..modules.model import sinusoidal_embedding_1d
|
| 16 |
+
from ..utils.multitalk_utils import get_attn_map_with_target, split_token_counts_and_frame_ids, normalize_and_scale
|
| 17 |
+
from ..modules.attention import SingleStreamAttention, SingleStreamMutiAttention
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def pad_freqs(original_tensor, target_len):
|
| 21 |
+
seq_len, s1, s2 = original_tensor.shape
|
| 22 |
+
pad_size = target_len - seq_len
|
| 23 |
+
padding_tensor = torch.ones(
|
| 24 |
+
pad_size,
|
| 25 |
+
s1,
|
| 26 |
+
s2,
|
| 27 |
+
dtype=original_tensor.dtype,
|
| 28 |
+
device=original_tensor.device)
|
| 29 |
+
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
| 30 |
+
return padded_tensor
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@amp.autocast(enabled=False)
|
| 34 |
+
def rope_apply(x, grid_sizes, freqs):
|
| 35 |
+
"""
|
| 36 |
+
x: [B, L, N, C].
|
| 37 |
+
grid_sizes: [B, 3].
|
| 38 |
+
freqs: [M, C // 2].
|
| 39 |
+
"""
|
| 40 |
+
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
| 41 |
+
# split freqs
|
| 42 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) # [[N, head_dim/2], [N, head_dim/2], [N, head_dim/2]] # T H W 极坐标
|
| 43 |
+
|
| 44 |
+
# loop over samples
|
| 45 |
+
output = []
|
| 46 |
+
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
| 47 |
+
seq_len = f * h * w
|
| 48 |
+
|
| 49 |
+
# precompute multipliers
|
| 50 |
+
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
|
| 51 |
+
s, n, -1, 2)) # [L, N, C/2] # 极坐标
|
| 52 |
+
freqs_i = torch.cat([
|
| 53 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 54 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 55 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 56 |
+
],
|
| 57 |
+
dim=-1).reshape(seq_len, 1, -1) # seq_lens, 1, 3 * dim / 2 (T H W)
|
| 58 |
+
|
| 59 |
+
# apply rotary embedding
|
| 60 |
+
sp_size = get_sequence_parallel_world_size()
|
| 61 |
+
sp_rank = get_sequence_parallel_rank()
|
| 62 |
+
freqs_i = pad_freqs(freqs_i, s * sp_size)
|
| 63 |
+
s_per_rank = s
|
| 64 |
+
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
|
| 65 |
+
s_per_rank), :, :]
|
| 66 |
+
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
|
| 67 |
+
x_i = torch.cat([x_i, x[i, s:]])
|
| 68 |
+
|
| 69 |
+
# append to collection
|
| 70 |
+
output.append(x_i)
|
| 71 |
+
return torch.stack(output).float()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
|
| 75 |
+
# embeddings
|
| 76 |
+
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
| 77 |
+
c = [u.flatten(2).transpose(1, 2) for u in c]
|
| 78 |
+
c = torch.cat([
|
| 79 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
|
| 80 |
+
for u in c
|
| 81 |
+
])
|
| 82 |
+
|
| 83 |
+
# arguments
|
| 84 |
+
new_kwargs = dict(x=x)
|
| 85 |
+
new_kwargs.update(kwargs)
|
| 86 |
+
|
| 87 |
+
# Context Parallel
|
| 88 |
+
c = torch.chunk(
|
| 89 |
+
c, get_sequence_parallel_world_size(),
|
| 90 |
+
dim=1)[get_sequence_parallel_rank()]
|
| 91 |
+
|
| 92 |
+
hints = []
|
| 93 |
+
for block in self.vace_blocks:
|
| 94 |
+
c, c_skip = block(c, **new_kwargs)
|
| 95 |
+
hints.append(c_skip)
|
| 96 |
+
return hints
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def usp_dit_forward(
|
| 100 |
+
self,
|
| 101 |
+
x,
|
| 102 |
+
t,
|
| 103 |
+
context,
|
| 104 |
+
seq_len,
|
| 105 |
+
vace_context=None,
|
| 106 |
+
vace_context_scale=1.0,
|
| 107 |
+
clip_fea=None,
|
| 108 |
+
y=None,
|
| 109 |
+
):
|
| 110 |
+
"""
|
| 111 |
+
x: A list of videos each with shape [C, T, H, W].
|
| 112 |
+
t: [B].
|
| 113 |
+
context: A list of text embeddings each with shape [L, C].
|
| 114 |
+
"""
|
| 115 |
+
if self.model_type == 'i2v':
|
| 116 |
+
assert clip_fea is not None and y is not None
|
| 117 |
+
# params
|
| 118 |
+
device = self.patch_embedding.weight.device
|
| 119 |
+
if self.freqs.device != device:
|
| 120 |
+
self.freqs = self.freqs.to(device)
|
| 121 |
+
|
| 122 |
+
if self.model_type != 'vace' and y is not None:
|
| 123 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 124 |
+
|
| 125 |
+
# embeddings
|
| 126 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 127 |
+
grid_sizes = torch.stack(
|
| 128 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 129 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 130 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 131 |
+
assert seq_lens.max() <= seq_len
|
| 132 |
+
x = torch.cat([
|
| 133 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
|
| 134 |
+
for u in x
|
| 135 |
+
])
|
| 136 |
+
|
| 137 |
+
# time embeddings
|
| 138 |
+
with amp.autocast(dtype=torch.float32):
|
| 139 |
+
e = self.time_embedding(
|
| 140 |
+
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 141 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
| 142 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 143 |
+
|
| 144 |
+
# context
|
| 145 |
+
context_lens = None
|
| 146 |
+
context = self.text_embedding(
|
| 147 |
+
torch.stack([
|
| 148 |
+
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 149 |
+
for u in context
|
| 150 |
+
]))
|
| 151 |
+
|
| 152 |
+
if self.model_type != 'vace' and clip_fea is not None:
|
| 153 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
| 154 |
+
context = torch.concat([context_clip, context], dim=1)
|
| 155 |
+
|
| 156 |
+
# arguments
|
| 157 |
+
kwargs = dict(
|
| 158 |
+
e=e0,
|
| 159 |
+
seq_lens=seq_lens,
|
| 160 |
+
grid_sizes=grid_sizes,
|
| 161 |
+
freqs=self.freqs,
|
| 162 |
+
context=context,
|
| 163 |
+
context_lens=context_lens)
|
| 164 |
+
|
| 165 |
+
# Context Parallel
|
| 166 |
+
x = torch.chunk(
|
| 167 |
+
x, get_sequence_parallel_world_size(),
|
| 168 |
+
dim=1)[get_sequence_parallel_rank()]
|
| 169 |
+
|
| 170 |
+
for block in self.blocks:
|
| 171 |
+
x = block(x, **kwargs)
|
| 172 |
+
|
| 173 |
+
# head
|
| 174 |
+
x = self.head(x, e)
|
| 175 |
+
|
| 176 |
+
# Context Parallel
|
| 177 |
+
x = get_sp_group().all_gather(x, dim=1)
|
| 178 |
+
|
| 179 |
+
# unpatchify
|
| 180 |
+
x = self.unpatchify(x, grid_sizes)
|
| 181 |
+
return [u.float() for u in x]
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def usp_attn_forward(self,
|
| 185 |
+
x,
|
| 186 |
+
seq_lens,
|
| 187 |
+
grid_sizes,
|
| 188 |
+
freqs,
|
| 189 |
+
dtype=torch.bfloat16):
|
| 190 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 191 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 192 |
+
|
| 193 |
+
def half(x):
|
| 194 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 195 |
+
|
| 196 |
+
# query, key, value function
|
| 197 |
+
def qkv_fn(x):
|
| 198 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 199 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 200 |
+
v = self.v(x).view(b, s, n, d)
|
| 201 |
+
return q, k, v
|
| 202 |
+
|
| 203 |
+
q, k, v = qkv_fn(x)
|
| 204 |
+
q = rope_apply(q, grid_sizes, freqs)
|
| 205 |
+
k = rope_apply(k, grid_sizes, freqs)
|
| 206 |
+
|
| 207 |
+
# TODO: We should use unpaded q,k,v for attention.
|
| 208 |
+
# k_lens = seq_lens // get_sequence_parallel_world_size()
|
| 209 |
+
# if k_lens is not None:
|
| 210 |
+
# q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
|
| 211 |
+
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
|
| 212 |
+
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
|
| 213 |
+
|
| 214 |
+
x = xFuserLongContextAttention()(
|
| 215 |
+
None,
|
| 216 |
+
query=half(q),
|
| 217 |
+
key=half(k),
|
| 218 |
+
value=half(v),
|
| 219 |
+
window_size=self.window_size)
|
| 220 |
+
|
| 221 |
+
# TODO: padding after attention.
|
| 222 |
+
# x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
|
| 223 |
+
|
| 224 |
+
# output
|
| 225 |
+
x = x.flatten(2)
|
| 226 |
+
x = self.o(x)
|
| 227 |
+
return x
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def usp_dit_forward_multitalk(
|
| 233 |
+
self,
|
| 234 |
+
x,
|
| 235 |
+
t,
|
| 236 |
+
context,
|
| 237 |
+
seq_len,
|
| 238 |
+
clip_fea=None,
|
| 239 |
+
y=None,
|
| 240 |
+
audio=None,
|
| 241 |
+
ref_target_masks=None,
|
| 242 |
+
):
|
| 243 |
+
"""
|
| 244 |
+
x: A list of videos each with shape [C, T, H, W].
|
| 245 |
+
t: [B].
|
| 246 |
+
context: A list of text embeddings each with shape [L, C].
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
assert clip_fea is not None and y is not None
|
| 250 |
+
# params
|
| 251 |
+
device = self.patch_embedding.weight.device
|
| 252 |
+
if self.freqs.device != device:
|
| 253 |
+
self.freqs = self.freqs.to(device)
|
| 254 |
+
|
| 255 |
+
_, T, H, W = x[0].shape
|
| 256 |
+
N_t = T // self.patch_size[0]
|
| 257 |
+
N_h = H // self.patch_size[1]
|
| 258 |
+
N_w = W // self.patch_size[2]
|
| 259 |
+
|
| 260 |
+
if y is not None:
|
| 261 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 262 |
+
x[0] = x[0].to(context[0].dtype)
|
| 263 |
+
|
| 264 |
+
# embeddings
|
| 265 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 266 |
+
grid_sizes = torch.stack(
|
| 267 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 268 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 269 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 270 |
+
assert seq_lens.max() <= seq_len
|
| 271 |
+
x = torch.cat([
|
| 272 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
|
| 273 |
+
for u in x
|
| 274 |
+
])
|
| 275 |
+
|
| 276 |
+
# time embeddings
|
| 277 |
+
with amp.autocast(dtype=torch.float32):
|
| 278 |
+
e = self.time_embedding(
|
| 279 |
+
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 280 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
| 281 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 282 |
+
|
| 283 |
+
# context
|
| 284 |
+
context_lens = None
|
| 285 |
+
context = self.text_embedding(
|
| 286 |
+
torch.stack([
|
| 287 |
+
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 288 |
+
for u in context
|
| 289 |
+
]))
|
| 290 |
+
|
| 291 |
+
if clip_fea is not None:
|
| 292 |
+
context_clip = self.img_emb(clip_fea)
|
| 293 |
+
context = torch.concat([context_clip, context], dim=1)
|
| 294 |
+
|
| 295 |
+
# get audio token
|
| 296 |
+
audio_cond = audio.to(device=x.device, dtype=x.dtype)
|
| 297 |
+
first_frame_audio_emb_s = audio_cond[:, :1, ...]
|
| 298 |
+
latter_frame_audio_emb = audio_cond[:, 1:, ...]
|
| 299 |
+
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=self.vae_scale)
|
| 300 |
+
middle_index = self.audio_window // 2
|
| 301 |
+
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
|
| 302 |
+
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
| 303 |
+
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
|
| 304 |
+
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
| 305 |
+
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
|
| 306 |
+
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
| 307 |
+
latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
|
| 308 |
+
audio_embedding = self.audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
|
| 309 |
+
human_num = len(audio_embedding)
|
| 310 |
+
audio_embedding = torch.concat(audio_embedding.split(1), dim=2).to(x.dtype)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# convert ref_target_masks to token_ref_target_masks
|
| 314 |
+
if ref_target_masks is not None:
|
| 315 |
+
ref_target_masks = ref_target_masks.unsqueeze(0).to(torch.float32)
|
| 316 |
+
token_ref_target_masks = nn.functional.interpolate(ref_target_masks, size=(N_h, N_w), mode='nearest')
|
| 317 |
+
token_ref_target_masks = token_ref_target_masks.squeeze(0)
|
| 318 |
+
token_ref_target_masks = (token_ref_target_masks > 0)
|
| 319 |
+
token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)
|
| 320 |
+
token_ref_target_masks = token_ref_target_masks.to(x.dtype)
|
| 321 |
+
|
| 322 |
+
if self.enable_teacache:
|
| 323 |
+
modulated_inp = e0 if self.use_ret_steps else e
|
| 324 |
+
if self.cnt%3==0: # cond
|
| 325 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
| 326 |
+
should_calc_cond = True
|
| 327 |
+
self.accumulated_rel_l1_distance_cond = 0
|
| 328 |
+
else:
|
| 329 |
+
rescale_func = np.poly1d(self.coefficients)
|
| 330 |
+
self.accumulated_rel_l1_distance_cond += rescale_func(((modulated_inp-self.previous_e0_cond).abs().mean() / self.previous_e0_cond.abs().mean()).cpu().item())
|
| 331 |
+
# print("accumulated_rel_l1_distance_even", self.accumulated_rel_l1_distance_even)
|
| 332 |
+
if self.accumulated_rel_l1_distance_cond < self.teacache_thresh:
|
| 333 |
+
should_calc_cond = False
|
| 334 |
+
else:
|
| 335 |
+
should_calc_cond = True
|
| 336 |
+
self.accumulated_rel_l1_distance_cond = 0
|
| 337 |
+
self.previous_e0_cond = modulated_inp.clone()
|
| 338 |
+
elif self.cnt%3==1: # drop_text
|
| 339 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
| 340 |
+
should_calc_drop_text = True
|
| 341 |
+
self.accumulated_rel_l1_distance_drop_text = 0
|
| 342 |
+
else:
|
| 343 |
+
rescale_func = np.poly1d(self.coefficients)
|
| 344 |
+
self.accumulated_rel_l1_distance_drop_text += rescale_func(((modulated_inp-self.previous_e0_drop_text).abs().mean() / self.previous_e0_drop_text.abs().mean()).cpu().item())
|
| 345 |
+
if self.accumulated_rel_l1_distance_drop_text < self.teacache_thresh:
|
| 346 |
+
should_calc_drop_text = False
|
| 347 |
+
else:
|
| 348 |
+
should_calc_drop_text = True
|
| 349 |
+
self.accumulated_rel_l1_distance_drop_text = 0
|
| 350 |
+
self.previous_e0_drop_text = modulated_inp.clone()
|
| 351 |
+
else: # uncond
|
| 352 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
| 353 |
+
should_calc_uncond = True
|
| 354 |
+
self.accumulated_rel_l1_distance_uncond = 0
|
| 355 |
+
else:
|
| 356 |
+
rescale_func = np.poly1d(self.coefficients)
|
| 357 |
+
self.accumulated_rel_l1_distance_uncond += rescale_func(((modulated_inp-self.previous_e0_uncond).abs().mean() / self.previous_e0_uncond.abs().mean()).cpu().item())
|
| 358 |
+
if self.accumulated_rel_l1_distance_uncond < self.teacache_thresh:
|
| 359 |
+
should_calc_uncond = False
|
| 360 |
+
else:
|
| 361 |
+
should_calc_uncond = True
|
| 362 |
+
self.accumulated_rel_l1_distance_uncond = 0
|
| 363 |
+
self.previous_e0_uncond = modulated_inp.clone()
|
| 364 |
+
|
| 365 |
+
# Context Parallel
|
| 366 |
+
x = torch.chunk(
|
| 367 |
+
x, get_sequence_parallel_world_size(),
|
| 368 |
+
dim=1)[get_sequence_parallel_rank()]
|
| 369 |
+
|
| 370 |
+
# arguments
|
| 371 |
+
kwargs = dict(
|
| 372 |
+
e=e0,
|
| 373 |
+
seq_lens=seq_lens,
|
| 374 |
+
grid_sizes=grid_sizes,
|
| 375 |
+
freqs=self.freqs,
|
| 376 |
+
context=context,
|
| 377 |
+
context_lens=context_lens,
|
| 378 |
+
audio_embedding=audio_embedding,
|
| 379 |
+
ref_target_masks=token_ref_target_masks,
|
| 380 |
+
human_num=human_num,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
if self.enable_teacache:
|
| 384 |
+
if self.cnt%3==0:
|
| 385 |
+
if not should_calc_cond:
|
| 386 |
+
x += self.previous_residual_cond
|
| 387 |
+
else:
|
| 388 |
+
ori_x = x.clone()
|
| 389 |
+
for block in self.blocks:
|
| 390 |
+
x = block(x, **kwargs)
|
| 391 |
+
self.previous_residual_cond = x - ori_x
|
| 392 |
+
elif self.cnt%3==1:
|
| 393 |
+
if not should_calc_drop_text:
|
| 394 |
+
x += self.previous_residual_drop_text
|
| 395 |
+
else:
|
| 396 |
+
ori_x = x.clone()
|
| 397 |
+
for block in self.blocks:
|
| 398 |
+
x = block(x, **kwargs)
|
| 399 |
+
self.previous_residual_drop_text = x - ori_x
|
| 400 |
+
else:
|
| 401 |
+
if not should_calc_uncond:
|
| 402 |
+
x += self.previous_residual_uncond
|
| 403 |
+
else:
|
| 404 |
+
ori_x = x.clone()
|
| 405 |
+
for block in self.blocks:
|
| 406 |
+
x = block(x, **kwargs)
|
| 407 |
+
self.previous_residual_uncond = x - ori_x
|
| 408 |
+
else:
|
| 409 |
+
for block in self.blocks:
|
| 410 |
+
x = block(x, **kwargs)
|
| 411 |
+
|
| 412 |
+
# head
|
| 413 |
+
x = self.head(x, e)
|
| 414 |
+
|
| 415 |
+
# Context Parallel
|
| 416 |
+
x = get_sp_group().all_gather(x, dim=1)
|
| 417 |
+
|
| 418 |
+
# unpatchify
|
| 419 |
+
x = self.unpatchify(x, grid_sizes)
|
| 420 |
+
if self.enable_teacache:
|
| 421 |
+
self.cnt += 1
|
| 422 |
+
if self.cnt >= self.num_steps:
|
| 423 |
+
self.cnt = 0
|
| 424 |
+
|
| 425 |
+
return torch.stack(x).float()
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def usp_attn_forward_multitalk(self,
|
| 429 |
+
x,
|
| 430 |
+
seq_lens,
|
| 431 |
+
grid_sizes,
|
| 432 |
+
freqs,
|
| 433 |
+
dtype=torch.bfloat16,
|
| 434 |
+
ref_target_masks=None):
|
| 435 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 436 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 437 |
+
|
| 438 |
+
def half(x):
|
| 439 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 440 |
+
|
| 441 |
+
# query, key, value function
|
| 442 |
+
def qkv_fn(x):
|
| 443 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 444 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 445 |
+
v = self.v(x).view(b, s, n, d)
|
| 446 |
+
return q, k, v
|
| 447 |
+
|
| 448 |
+
q, k, v = qkv_fn(x)
|
| 449 |
+
q = rope_apply(q, grid_sizes, freqs)
|
| 450 |
+
k = rope_apply(k, grid_sizes, freqs)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
x = xFuserLongContextAttention()(
|
| 454 |
+
None,
|
| 455 |
+
query=half(q),
|
| 456 |
+
key=half(k),
|
| 457 |
+
value=half(v),
|
| 458 |
+
window_size=self.window_size)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
# output
|
| 462 |
+
x = x.flatten(2)
|
| 463 |
+
x = self.o(x)
|
| 464 |
+
|
| 465 |
+
with torch.no_grad():
|
| 466 |
+
x_ref_attn_map = get_attn_map_with_target(q.type_as(x), k.type_as(x), grid_sizes[0],
|
| 467 |
+
ref_target_masks=ref_target_masks, enable_sp=True)
|
| 468 |
+
|
| 469 |
+
return x, x_ref_attn_map
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def usp_crossattn_multi_forward_multitalk(self,
|
| 475 |
+
x: torch.Tensor,
|
| 476 |
+
encoder_hidden_states: torch.Tensor, # 1, 21, 64, C
|
| 477 |
+
shape=None,
|
| 478 |
+
x_ref_attn_map=None,
|
| 479 |
+
human_num=None) -> torch.Tensor:
|
| 480 |
+
|
| 481 |
+
N_t, N_h, N_w = shape
|
| 482 |
+
sp_size = get_sequence_parallel_world_size()
|
| 483 |
+
sp_rank = get_sequence_parallel_rank()
|
| 484 |
+
audio_tokens_per_frame = 32
|
| 485 |
+
visual_seqlen, frame_ids = split_token_counts_and_frame_ids(N_t, N_h * N_w, sp_size, sp_rank)
|
| 486 |
+
encoder_hidden_states = encoder_hidden_states[:, min(frame_ids):max(frame_ids)+1, ...]
|
| 487 |
+
encoder_hidden_states = rearrange(encoder_hidden_states, "B T N C -> B (T N) C")
|
| 488 |
+
N_a = len(frame_ids)
|
| 489 |
+
kv_seq = [audio_tokens_per_frame * human_num] * N_a
|
| 490 |
+
|
| 491 |
+
if human_num == 1:
|
| 492 |
+
return super(SingleStreamMutiAttention, self).forward(x, encoder_hidden_states, shape, enable_sp=True, kv_seq=kv_seq)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
# get q for hidden_state
|
| 496 |
+
B, N, C = x.shape
|
| 497 |
+
q = self.q_linear(x)
|
| 498 |
+
q_shape = (B, N, self.num_heads, self.head_dim)
|
| 499 |
+
q = q.view(q_shape).permute((0, 2, 1, 3))
|
| 500 |
+
|
| 501 |
+
if self.qk_norm:
|
| 502 |
+
q = self.q_norm(q)
|
| 503 |
+
|
| 504 |
+
max_values = x_ref_attn_map.max(1).values[:, None, None]
|
| 505 |
+
min_values = x_ref_attn_map.min(1).values[:, None, None]
|
| 506 |
+
max_min_values = torch.cat([max_values, min_values], dim=2)
|
| 507 |
+
max_min_values = get_sp_group().all_gather(max_min_values, dim=1)
|
| 508 |
+
|
| 509 |
+
human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
|
| 510 |
+
human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
|
| 511 |
+
|
| 512 |
+
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
|
| 513 |
+
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
|
| 514 |
+
back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device)
|
| 515 |
+
max_indices = x_ref_attn_map.argmax(dim=0)
|
| 516 |
+
normalized_map = torch.stack([human1, human2, back], dim=1)
|
| 517 |
+
normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N
|
| 518 |
+
q = self.rope_1d(q, normalized_pos)
|
| 519 |
+
|
| 520 |
+
encoder_kv = self.kv_linear(encoder_hidden_states)
|
| 521 |
+
encoder_kv_shape = (B, encoder_hidden_states.size(1), 2, self.num_heads, self.head_dim)
|
| 522 |
+
encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
|
| 523 |
+
encoder_k, encoder_v = encoder_kv.unbind(0) # B H N C
|
| 524 |
+
|
| 525 |
+
if self.qk_norm:
|
| 526 |
+
encoder_k = self.add_k_norm(encoder_k)
|
| 527 |
+
|
| 528 |
+
# position embedding for condition audio embeddings
|
| 529 |
+
per_frame = torch.zeros(audio_tokens_per_frame * human_num, dtype=encoder_k.dtype).to(encoder_k.device)
|
| 530 |
+
per_frame[:audio_tokens_per_frame] = (self.rope_h1[0] + self.rope_h1[1]) / 2
|
| 531 |
+
per_frame[audio_tokens_per_frame:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
|
| 532 |
+
encoder_pos = torch.concat([per_frame]*N_a, dim=0)
|
| 533 |
+
encoder_k = self.rope_1d(encoder_k, encoder_pos)
|
| 534 |
+
|
| 535 |
+
# get attn
|
| 536 |
+
q = rearrange(q, "B H M K -> B M H K")
|
| 537 |
+
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
|
| 538 |
+
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
|
| 539 |
+
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(visual_seqlen, kv_seq)
|
| 540 |
+
x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
|
| 541 |
+
x = rearrange(x, "B M H K -> B H M K")
|
| 542 |
+
|
| 543 |
+
# linear transform
|
| 544 |
+
x_output_shape = (B, N, C)
|
| 545 |
+
x = x.transpose(1, 2)
|
| 546 |
+
x = x.reshape(x_output_shape)
|
| 547 |
+
x = self.proj(x)
|
| 548 |
+
x = self.proj_drop(x)
|
| 549 |
+
|
| 550 |
+
return x
|
wan/first_last_frame2video.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import gc
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import sys
|
| 8 |
+
import types
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
from functools import partial
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.cuda.amp as amp
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
import torchvision.transforms.functional as TF
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
from .distributed.fsdp import shard_model
|
| 20 |
+
from .modules.clip import CLIPModel
|
| 21 |
+
from .modules.model import WanModel
|
| 22 |
+
from .modules.t5 import T5EncoderModel
|
| 23 |
+
from .modules.vae import WanVAE
|
| 24 |
+
from .utils.fm_solvers import (
|
| 25 |
+
FlowDPMSolverMultistepScheduler,
|
| 26 |
+
get_sampling_sigmas,
|
| 27 |
+
retrieve_timesteps,
|
| 28 |
+
)
|
| 29 |
+
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class WanFLF2V:
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
config,
|
| 37 |
+
checkpoint_dir,
|
| 38 |
+
device_id=0,
|
| 39 |
+
rank=0,
|
| 40 |
+
t5_fsdp=False,
|
| 41 |
+
dit_fsdp=False,
|
| 42 |
+
use_usp=False,
|
| 43 |
+
t5_cpu=False,
|
| 44 |
+
init_on_cpu=True,
|
| 45 |
+
):
|
| 46 |
+
r"""
|
| 47 |
+
Initializes the image-to-video generation model components.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
config (EasyDict):
|
| 51 |
+
Object containing model parameters initialized from config.py
|
| 52 |
+
checkpoint_dir (`str`):
|
| 53 |
+
Path to directory containing model checkpoints
|
| 54 |
+
device_id (`int`, *optional*, defaults to 0):
|
| 55 |
+
Id of target GPU device
|
| 56 |
+
rank (`int`, *optional*, defaults to 0):
|
| 57 |
+
Process rank for distributed training
|
| 58 |
+
t5_fsdp (`bool`, *optional*, defaults to False):
|
| 59 |
+
Enable FSDP sharding for T5 model
|
| 60 |
+
dit_fsdp (`bool`, *optional*, defaults to False):
|
| 61 |
+
Enable FSDP sharding for DiT model
|
| 62 |
+
use_usp (`bool`, *optional*, defaults to False):
|
| 63 |
+
Enable distribution strategy of USP.
|
| 64 |
+
t5_cpu (`bool`, *optional*, defaults to False):
|
| 65 |
+
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
| 66 |
+
init_on_cpu (`bool`, *optional*, defaults to True):
|
| 67 |
+
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
|
| 68 |
+
"""
|
| 69 |
+
self.device = torch.device(f"cuda:{device_id}")
|
| 70 |
+
self.config = config
|
| 71 |
+
self.rank = rank
|
| 72 |
+
self.use_usp = use_usp
|
| 73 |
+
self.t5_cpu = t5_cpu
|
| 74 |
+
|
| 75 |
+
self.num_train_timesteps = config.num_train_timesteps
|
| 76 |
+
self.param_dtype = config.param_dtype
|
| 77 |
+
|
| 78 |
+
shard_fn = partial(shard_model, device_id=device_id)
|
| 79 |
+
self.text_encoder = T5EncoderModel(
|
| 80 |
+
text_len=config.text_len,
|
| 81 |
+
dtype=config.t5_dtype,
|
| 82 |
+
device=torch.device('cpu'),
|
| 83 |
+
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
|
| 84 |
+
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
| 85 |
+
shard_fn=shard_fn if t5_fsdp else None,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
self.vae_stride = config.vae_stride
|
| 89 |
+
self.patch_size = config.patch_size
|
| 90 |
+
self.vae = WanVAE(
|
| 91 |
+
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
| 92 |
+
device=self.device)
|
| 93 |
+
|
| 94 |
+
self.clip = CLIPModel(
|
| 95 |
+
dtype=config.clip_dtype,
|
| 96 |
+
device=self.device,
|
| 97 |
+
checkpoint_path=os.path.join(checkpoint_dir,
|
| 98 |
+
config.clip_checkpoint),
|
| 99 |
+
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
| 100 |
+
|
| 101 |
+
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
| 102 |
+
self.model = WanModel.from_pretrained(checkpoint_dir)
|
| 103 |
+
self.model.eval().requires_grad_(False)
|
| 104 |
+
|
| 105 |
+
if t5_fsdp or dit_fsdp or use_usp:
|
| 106 |
+
init_on_cpu = False
|
| 107 |
+
|
| 108 |
+
if use_usp:
|
| 109 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
| 110 |
+
|
| 111 |
+
from .distributed.xdit_context_parallel import (
|
| 112 |
+
usp_attn_forward,
|
| 113 |
+
usp_dit_forward,
|
| 114 |
+
)
|
| 115 |
+
for block in self.model.blocks:
|
| 116 |
+
block.self_attn.forward = types.MethodType(
|
| 117 |
+
usp_attn_forward, block.self_attn)
|
| 118 |
+
self.model.forward = types.MethodType(usp_dit_forward, self.model)
|
| 119 |
+
self.sp_size = get_sequence_parallel_world_size()
|
| 120 |
+
else:
|
| 121 |
+
self.sp_size = 1
|
| 122 |
+
|
| 123 |
+
if dist.is_initialized():
|
| 124 |
+
dist.barrier()
|
| 125 |
+
if dit_fsdp:
|
| 126 |
+
self.model = shard_fn(self.model)
|
| 127 |
+
else:
|
| 128 |
+
if not init_on_cpu:
|
| 129 |
+
self.model.to(self.device)
|
| 130 |
+
|
| 131 |
+
self.sample_neg_prompt = config.sample_neg_prompt
|
| 132 |
+
|
| 133 |
+
def generate(self,
|
| 134 |
+
input_prompt,
|
| 135 |
+
first_frame,
|
| 136 |
+
last_frame,
|
| 137 |
+
max_area=720 * 1280,
|
| 138 |
+
frame_num=81,
|
| 139 |
+
shift=16,
|
| 140 |
+
sample_solver='unipc',
|
| 141 |
+
sampling_steps=50,
|
| 142 |
+
guide_scale=5.5,
|
| 143 |
+
n_prompt="",
|
| 144 |
+
seed=-1,
|
| 145 |
+
offload_model=True):
|
| 146 |
+
r"""
|
| 147 |
+
Generates video frames from input first-last frame and text prompt using diffusion process.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
input_prompt (`str`):
|
| 151 |
+
Text prompt for content generation.
|
| 152 |
+
first_frame (PIL.Image.Image):
|
| 153 |
+
Input image tensor. Shape: [3, H, W]
|
| 154 |
+
last_frame (PIL.Image.Image):
|
| 155 |
+
Input image tensor. Shape: [3, H, W]
|
| 156 |
+
[NOTE] If the sizes of first_frame and last_frame are mismatched, last_frame will be cropped & resized
|
| 157 |
+
to match first_frame.
|
| 158 |
+
max_area (`int`, *optional*, defaults to 720*1280):
|
| 159 |
+
Maximum pixel area for latent space calculation. Controls video resolution scaling
|
| 160 |
+
frame_num (`int`, *optional*, defaults to 81):
|
| 161 |
+
How many frames to sample from a video. The number should be 4n+1
|
| 162 |
+
shift (`float`, *optional*, defaults to 5.0):
|
| 163 |
+
Noise schedule shift parameter. Affects temporal dynamics
|
| 164 |
+
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
|
| 165 |
+
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
| 166 |
+
Solver used to sample the video.
|
| 167 |
+
sampling_steps (`int`, *optional*, defaults to 40):
|
| 168 |
+
Number of diffusion sampling steps. Higher values improve quality but slow generation
|
| 169 |
+
guide_scale (`float`, *optional*, defaults 5.0):
|
| 170 |
+
Classifier-free guidance scale. Controls prompt adherence vs. creativity
|
| 171 |
+
n_prompt (`str`, *optional*, defaults to ""):
|
| 172 |
+
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
|
| 173 |
+
seed (`int`, *optional*, defaults to -1):
|
| 174 |
+
Random seed for noise generation. If -1, use random seed
|
| 175 |
+
offload_model (`bool`, *optional*, defaults to True):
|
| 176 |
+
If True, offloads models to CPU during generation to save VRAM
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
torch.Tensor:
|
| 180 |
+
Generated video frames tensor. Dimensions: (C, N H, W) where:
|
| 181 |
+
- C: Color channels (3 for RGB)
|
| 182 |
+
- N: Number of frames (81)
|
| 183 |
+
- H: Frame height (from max_area)
|
| 184 |
+
- W: Frame width from max_area)
|
| 185 |
+
"""
|
| 186 |
+
first_frame_size = first_frame.size
|
| 187 |
+
last_frame_size = last_frame.size
|
| 188 |
+
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
|
| 189 |
+
self.device)
|
| 190 |
+
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
|
| 191 |
+
self.device)
|
| 192 |
+
|
| 193 |
+
F = frame_num
|
| 194 |
+
first_frame_h, first_frame_w = first_frame.shape[1:]
|
| 195 |
+
aspect_ratio = first_frame_h / first_frame_w
|
| 196 |
+
lat_h = round(
|
| 197 |
+
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
|
| 198 |
+
self.patch_size[1] * self.patch_size[1])
|
| 199 |
+
lat_w = round(
|
| 200 |
+
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
|
| 201 |
+
self.patch_size[2] * self.patch_size[2])
|
| 202 |
+
first_frame_h = lat_h * self.vae_stride[1]
|
| 203 |
+
first_frame_w = lat_w * self.vae_stride[2]
|
| 204 |
+
if first_frame_size != last_frame_size:
|
| 205 |
+
# 1. resize
|
| 206 |
+
last_frame_resize_ratio = max(
|
| 207 |
+
first_frame_size[0] / last_frame_size[0],
|
| 208 |
+
first_frame_size[1] / last_frame_size[1])
|
| 209 |
+
last_frame_size = [
|
| 210 |
+
round(last_frame_size[0] * last_frame_resize_ratio),
|
| 211 |
+
round(last_frame_size[1] * last_frame_resize_ratio),
|
| 212 |
+
]
|
| 213 |
+
# 2. center crop
|
| 214 |
+
last_frame = TF.center_crop(last_frame, last_frame_size)
|
| 215 |
+
|
| 216 |
+
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
|
| 217 |
+
self.patch_size[1] * self.patch_size[2])
|
| 218 |
+
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
|
| 219 |
+
|
| 220 |
+
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
| 221 |
+
seed_g = torch.Generator(device=self.device)
|
| 222 |
+
seed_g.manual_seed(seed)
|
| 223 |
+
noise = torch.randn(
|
| 224 |
+
16, (F - 1) // 4 + 1,
|
| 225 |
+
lat_h,
|
| 226 |
+
lat_w,
|
| 227 |
+
dtype=torch.float32,
|
| 228 |
+
generator=seed_g,
|
| 229 |
+
device=self.device)
|
| 230 |
+
|
| 231 |
+
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
|
| 232 |
+
msk[:, 1:-1] = 0
|
| 233 |
+
msk = torch.concat([
|
| 234 |
+
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
|
| 235 |
+
],
|
| 236 |
+
dim=1)
|
| 237 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
| 238 |
+
msk = msk.transpose(1, 2)[0]
|
| 239 |
+
|
| 240 |
+
if n_prompt == "":
|
| 241 |
+
n_prompt = self.sample_neg_prompt
|
| 242 |
+
|
| 243 |
+
# preprocess
|
| 244 |
+
if not self.t5_cpu:
|
| 245 |
+
self.text_encoder.model.to(self.device)
|
| 246 |
+
context = self.text_encoder([input_prompt], self.device)
|
| 247 |
+
context_null = self.text_encoder([n_prompt], self.device)
|
| 248 |
+
if offload_model:
|
| 249 |
+
self.text_encoder.model.cpu()
|
| 250 |
+
else:
|
| 251 |
+
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
| 252 |
+
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
| 253 |
+
context = [t.to(self.device) for t in context]
|
| 254 |
+
context_null = [t.to(self.device) for t in context_null]
|
| 255 |
+
|
| 256 |
+
self.clip.model.to(self.device)
|
| 257 |
+
clip_context = self.clip.visual(
|
| 258 |
+
[first_frame[:, None, :, :], last_frame[:, None, :, :]])
|
| 259 |
+
if offload_model:
|
| 260 |
+
self.clip.model.cpu()
|
| 261 |
+
|
| 262 |
+
y = self.vae.encode([
|
| 263 |
+
torch.concat([
|
| 264 |
+
torch.nn.functional.interpolate(
|
| 265 |
+
first_frame[None].cpu(),
|
| 266 |
+
size=(first_frame_h, first_frame_w),
|
| 267 |
+
mode='bicubic').transpose(0, 1),
|
| 268 |
+
torch.zeros(3, F - 2, first_frame_h, first_frame_w),
|
| 269 |
+
torch.nn.functional.interpolate(
|
| 270 |
+
last_frame[None].cpu(),
|
| 271 |
+
size=(first_frame_h, first_frame_w),
|
| 272 |
+
mode='bicubic').transpose(0, 1),
|
| 273 |
+
],
|
| 274 |
+
dim=1).to(self.device)
|
| 275 |
+
])[0]
|
| 276 |
+
y = torch.concat([msk, y])
|
| 277 |
+
|
| 278 |
+
@contextmanager
|
| 279 |
+
def noop_no_sync():
|
| 280 |
+
yield
|
| 281 |
+
|
| 282 |
+
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
| 283 |
+
|
| 284 |
+
# evaluation mode
|
| 285 |
+
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
|
| 286 |
+
|
| 287 |
+
if sample_solver == 'unipc':
|
| 288 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
| 289 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 290 |
+
shift=1,
|
| 291 |
+
use_dynamic_shifting=False)
|
| 292 |
+
sample_scheduler.set_timesteps(
|
| 293 |
+
sampling_steps, device=self.device, shift=shift)
|
| 294 |
+
timesteps = sample_scheduler.timesteps
|
| 295 |
+
elif sample_solver == 'dpm++':
|
| 296 |
+
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
| 297 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 298 |
+
shift=1,
|
| 299 |
+
use_dynamic_shifting=False)
|
| 300 |
+
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
| 301 |
+
timesteps, _ = retrieve_timesteps(
|
| 302 |
+
sample_scheduler,
|
| 303 |
+
device=self.device,
|
| 304 |
+
sigmas=sampling_sigmas)
|
| 305 |
+
else:
|
| 306 |
+
raise NotImplementedError("Unsupported solver.")
|
| 307 |
+
|
| 308 |
+
# sample videos
|
| 309 |
+
latent = noise
|
| 310 |
+
|
| 311 |
+
arg_c = {
|
| 312 |
+
'context': [context[0]],
|
| 313 |
+
'clip_fea': clip_context,
|
| 314 |
+
'seq_len': max_seq_len,
|
| 315 |
+
'y': [y],
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
arg_null = {
|
| 319 |
+
'context': context_null,
|
| 320 |
+
'clip_fea': clip_context,
|
| 321 |
+
'seq_len': max_seq_len,
|
| 322 |
+
'y': [y],
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
if offload_model:
|
| 326 |
+
torch.cuda.empty_cache()
|
| 327 |
+
|
| 328 |
+
self.model.to(self.device)
|
| 329 |
+
for _, t in enumerate(tqdm(timesteps)):
|
| 330 |
+
latent_model_input = [latent.to(self.device)]
|
| 331 |
+
timestep = [t]
|
| 332 |
+
|
| 333 |
+
timestep = torch.stack(timestep).to(self.device)
|
| 334 |
+
|
| 335 |
+
noise_pred_cond = self.model(
|
| 336 |
+
latent_model_input, t=timestep, **arg_c)[0].to(
|
| 337 |
+
torch.device('cpu') if offload_model else self.device)
|
| 338 |
+
if offload_model:
|
| 339 |
+
torch.cuda.empty_cache()
|
| 340 |
+
noise_pred_uncond = self.model(
|
| 341 |
+
latent_model_input, t=timestep, **arg_null)[0].to(
|
| 342 |
+
torch.device('cpu') if offload_model else self.device)
|
| 343 |
+
if offload_model:
|
| 344 |
+
torch.cuda.empty_cache()
|
| 345 |
+
noise_pred = noise_pred_uncond + guide_scale * (
|
| 346 |
+
noise_pred_cond - noise_pred_uncond)
|
| 347 |
+
|
| 348 |
+
latent = latent.to(
|
| 349 |
+
torch.device('cpu') if offload_model else self.device)
|
| 350 |
+
|
| 351 |
+
temp_x0 = sample_scheduler.step(
|
| 352 |
+
noise_pred.unsqueeze(0),
|
| 353 |
+
t,
|
| 354 |
+
latent.unsqueeze(0),
|
| 355 |
+
return_dict=False,
|
| 356 |
+
generator=seed_g)[0]
|
| 357 |
+
latent = temp_x0.squeeze(0)
|
| 358 |
+
|
| 359 |
+
x0 = [latent.to(self.device)]
|
| 360 |
+
del latent_model_input, timestep
|
| 361 |
+
|
| 362 |
+
if offload_model:
|
| 363 |
+
self.model.cpu()
|
| 364 |
+
torch.cuda.empty_cache()
|
| 365 |
+
|
| 366 |
+
if self.rank == 0:
|
| 367 |
+
videos = self.vae.decode(x0)
|
| 368 |
+
|
| 369 |
+
del noise, latent
|
| 370 |
+
del sample_scheduler
|
| 371 |
+
if offload_model:
|
| 372 |
+
gc.collect()
|
| 373 |
+
torch.cuda.synchronize()
|
| 374 |
+
if dist.is_initialized():
|
| 375 |
+
dist.barrier()
|
| 376 |
+
|
| 377 |
+
return videos[0] if self.rank == 0 else None
|
wan/image2video.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import gc
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import sys
|
| 8 |
+
import types
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
from functools import partial
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.cuda.amp as amp
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
import torchvision.transforms.functional as TF
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
from .distributed.fsdp import shard_model
|
| 20 |
+
from .modules.clip import CLIPModel
|
| 21 |
+
from .modules.model import WanModel
|
| 22 |
+
from .modules.t5 import T5EncoderModel
|
| 23 |
+
from .modules.vae import WanVAE
|
| 24 |
+
from .utils.fm_solvers import (
|
| 25 |
+
FlowDPMSolverMultistepScheduler,
|
| 26 |
+
get_sampling_sigmas,
|
| 27 |
+
retrieve_timesteps,
|
| 28 |
+
)
|
| 29 |
+
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class WanI2V:
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
config,
|
| 37 |
+
checkpoint_dir,
|
| 38 |
+
device_id=0,
|
| 39 |
+
rank=0,
|
| 40 |
+
t5_fsdp=False,
|
| 41 |
+
dit_fsdp=False,
|
| 42 |
+
use_usp=False,
|
| 43 |
+
t5_cpu=False,
|
| 44 |
+
init_on_cpu=True,
|
| 45 |
+
):
|
| 46 |
+
r"""
|
| 47 |
+
Initializes the image-to-video generation model components.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
config (EasyDict):
|
| 51 |
+
Object containing model parameters initialized from config.py
|
| 52 |
+
checkpoint_dir (`str`):
|
| 53 |
+
Path to directory containing model checkpoints
|
| 54 |
+
device_id (`int`, *optional*, defaults to 0):
|
| 55 |
+
Id of target GPU device
|
| 56 |
+
rank (`int`, *optional*, defaults to 0):
|
| 57 |
+
Process rank for distributed training
|
| 58 |
+
t5_fsdp (`bool`, *optional*, defaults to False):
|
| 59 |
+
Enable FSDP sharding for T5 model
|
| 60 |
+
dit_fsdp (`bool`, *optional*, defaults to False):
|
| 61 |
+
Enable FSDP sharding for DiT model
|
| 62 |
+
use_usp (`bool`, *optional*, defaults to False):
|
| 63 |
+
Enable distribution strategy of USP.
|
| 64 |
+
t5_cpu (`bool`, *optional*, defaults to False):
|
| 65 |
+
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
| 66 |
+
init_on_cpu (`bool`, *optional*, defaults to True):
|
| 67 |
+
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
|
| 68 |
+
"""
|
| 69 |
+
self.device = torch.device(f"cuda:{device_id}")
|
| 70 |
+
self.config = config
|
| 71 |
+
self.rank = rank
|
| 72 |
+
self.use_usp = use_usp
|
| 73 |
+
self.t5_cpu = t5_cpu
|
| 74 |
+
|
| 75 |
+
self.num_train_timesteps = config.num_train_timesteps
|
| 76 |
+
self.param_dtype = config.param_dtype
|
| 77 |
+
|
| 78 |
+
shard_fn = partial(shard_model, device_id=device_id)
|
| 79 |
+
self.text_encoder = T5EncoderModel(
|
| 80 |
+
text_len=config.text_len,
|
| 81 |
+
dtype=config.t5_dtype,
|
| 82 |
+
device=torch.device('cpu'),
|
| 83 |
+
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
|
| 84 |
+
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
| 85 |
+
shard_fn=shard_fn if t5_fsdp else None,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
self.vae_stride = config.vae_stride
|
| 89 |
+
self.patch_size = config.patch_size
|
| 90 |
+
self.vae = WanVAE(
|
| 91 |
+
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
| 92 |
+
device=self.device)
|
| 93 |
+
|
| 94 |
+
self.clip = CLIPModel(
|
| 95 |
+
dtype=config.clip_dtype,
|
| 96 |
+
device=self.device,
|
| 97 |
+
checkpoint_path=os.path.join(checkpoint_dir,
|
| 98 |
+
config.clip_checkpoint),
|
| 99 |
+
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
| 100 |
+
|
| 101 |
+
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
| 102 |
+
self.model = WanModel.from_pretrained(checkpoint_dir)
|
| 103 |
+
self.model.eval().requires_grad_(False)
|
| 104 |
+
|
| 105 |
+
if t5_fsdp or dit_fsdp or use_usp:
|
| 106 |
+
init_on_cpu = False
|
| 107 |
+
|
| 108 |
+
if use_usp:
|
| 109 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
| 110 |
+
|
| 111 |
+
from .distributed.xdit_context_parallel import (
|
| 112 |
+
usp_attn_forward,
|
| 113 |
+
usp_dit_forward,
|
| 114 |
+
)
|
| 115 |
+
for block in self.model.blocks:
|
| 116 |
+
block.self_attn.forward = types.MethodType(
|
| 117 |
+
usp_attn_forward, block.self_attn)
|
| 118 |
+
self.model.forward = types.MethodType(usp_dit_forward, self.model)
|
| 119 |
+
self.sp_size = get_sequence_parallel_world_size()
|
| 120 |
+
else:
|
| 121 |
+
self.sp_size = 1
|
| 122 |
+
|
| 123 |
+
if dist.is_initialized():
|
| 124 |
+
dist.barrier()
|
| 125 |
+
if dit_fsdp:
|
| 126 |
+
self.model = shard_fn(self.model)
|
| 127 |
+
else:
|
| 128 |
+
if not init_on_cpu:
|
| 129 |
+
self.model.to(self.device)
|
| 130 |
+
|
| 131 |
+
self.sample_neg_prompt = config.sample_neg_prompt
|
| 132 |
+
|
| 133 |
+
def generate(self,
|
| 134 |
+
input_prompt,
|
| 135 |
+
img,
|
| 136 |
+
max_area=720 * 1280,
|
| 137 |
+
frame_num=81,
|
| 138 |
+
shift=5.0,
|
| 139 |
+
sample_solver='unipc',
|
| 140 |
+
sampling_steps=40,
|
| 141 |
+
guide_scale=5.0,
|
| 142 |
+
n_prompt="",
|
| 143 |
+
seed=-1,
|
| 144 |
+
offload_model=True):
|
| 145 |
+
r"""
|
| 146 |
+
Generates video frames from input image and text prompt using diffusion process.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
input_prompt (`str`):
|
| 150 |
+
Text prompt for content generation.
|
| 151 |
+
img (PIL.Image.Image):
|
| 152 |
+
Input image tensor. Shape: [3, H, W]
|
| 153 |
+
max_area (`int`, *optional*, defaults to 720*1280):
|
| 154 |
+
Maximum pixel area for latent space calculation. Controls video resolution scaling
|
| 155 |
+
frame_num (`int`, *optional*, defaults to 81):
|
| 156 |
+
How many frames to sample from a video. The number should be 4n+1
|
| 157 |
+
shift (`float`, *optional*, defaults to 5.0):
|
| 158 |
+
Noise schedule shift parameter. Affects temporal dynamics
|
| 159 |
+
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
|
| 160 |
+
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
| 161 |
+
Solver used to sample the video.
|
| 162 |
+
sampling_steps (`int`, *optional*, defaults to 40):
|
| 163 |
+
Number of diffusion sampling steps. Higher values improve quality but slow generation
|
| 164 |
+
guide_scale (`float`, *optional*, defaults 5.0):
|
| 165 |
+
Classifier-free guidance scale. Controls prompt adherence vs. creativity
|
| 166 |
+
n_prompt (`str`, *optional*, defaults to ""):
|
| 167 |
+
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
|
| 168 |
+
seed (`int`, *optional*, defaults to -1):
|
| 169 |
+
Random seed for noise generation. If -1, use random seed
|
| 170 |
+
offload_model (`bool`, *optional*, defaults to True):
|
| 171 |
+
If True, offloads models to CPU during generation to save VRAM
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
torch.Tensor:
|
| 175 |
+
Generated video frames tensor. Dimensions: (C, N H, W) where:
|
| 176 |
+
- C: Color channels (3 for RGB)
|
| 177 |
+
- N: Number of frames (81)
|
| 178 |
+
- H: Frame height (from max_area)
|
| 179 |
+
- W: Frame width from max_area)
|
| 180 |
+
"""
|
| 181 |
+
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
|
| 182 |
+
|
| 183 |
+
F = frame_num
|
| 184 |
+
h, w = img.shape[1:]
|
| 185 |
+
aspect_ratio = h / w
|
| 186 |
+
lat_h = round(
|
| 187 |
+
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
|
| 188 |
+
self.patch_size[1] * self.patch_size[1])
|
| 189 |
+
lat_w = round(
|
| 190 |
+
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
|
| 191 |
+
self.patch_size[2] * self.patch_size[2])
|
| 192 |
+
h = lat_h * self.vae_stride[1]
|
| 193 |
+
w = lat_w * self.vae_stride[2]
|
| 194 |
+
|
| 195 |
+
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
|
| 196 |
+
self.patch_size[1] * self.patch_size[2])
|
| 197 |
+
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
|
| 198 |
+
|
| 199 |
+
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
| 200 |
+
seed_g = torch.Generator(device=self.device)
|
| 201 |
+
seed_g.manual_seed(seed)
|
| 202 |
+
noise = torch.randn(
|
| 203 |
+
16, (F - 1) // 4 + 1,
|
| 204 |
+
lat_h,
|
| 205 |
+
lat_w,
|
| 206 |
+
dtype=torch.float32,
|
| 207 |
+
generator=seed_g,
|
| 208 |
+
device=self.device)
|
| 209 |
+
|
| 210 |
+
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
|
| 211 |
+
msk[:, 1:] = 0
|
| 212 |
+
msk = torch.concat([
|
| 213 |
+
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
|
| 214 |
+
],
|
| 215 |
+
dim=1)
|
| 216 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
| 217 |
+
msk = msk.transpose(1, 2)[0]
|
| 218 |
+
|
| 219 |
+
if n_prompt == "":
|
| 220 |
+
n_prompt = self.sample_neg_prompt
|
| 221 |
+
|
| 222 |
+
# preprocess
|
| 223 |
+
if not self.t5_cpu:
|
| 224 |
+
self.text_encoder.model.to(self.device)
|
| 225 |
+
context = self.text_encoder([input_prompt], self.device)
|
| 226 |
+
context_null = self.text_encoder([n_prompt], self.device)
|
| 227 |
+
if offload_model:
|
| 228 |
+
self.text_encoder.model.cpu()
|
| 229 |
+
else:
|
| 230 |
+
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
| 231 |
+
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
| 232 |
+
context = [t.to(self.device) for t in context]
|
| 233 |
+
context_null = [t.to(self.device) for t in context_null]
|
| 234 |
+
|
| 235 |
+
self.clip.model.to(self.device)
|
| 236 |
+
clip_context = self.clip.visual([img[:, None, :, :]])
|
| 237 |
+
if offload_model:
|
| 238 |
+
self.clip.model.cpu()
|
| 239 |
+
|
| 240 |
+
y = self.vae.encode([
|
| 241 |
+
torch.concat([
|
| 242 |
+
torch.nn.functional.interpolate(
|
| 243 |
+
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
|
| 244 |
+
0, 1),
|
| 245 |
+
torch.zeros(3, F - 1, h, w)
|
| 246 |
+
],
|
| 247 |
+
dim=1).to(self.device)
|
| 248 |
+
])[0]
|
| 249 |
+
y = torch.concat([msk, y])
|
| 250 |
+
|
| 251 |
+
@contextmanager
|
| 252 |
+
def noop_no_sync():
|
| 253 |
+
yield
|
| 254 |
+
|
| 255 |
+
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
| 256 |
+
|
| 257 |
+
# evaluation mode
|
| 258 |
+
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
|
| 259 |
+
|
| 260 |
+
if sample_solver == 'unipc':
|
| 261 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
| 262 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 263 |
+
shift=1,
|
| 264 |
+
use_dynamic_shifting=False)
|
| 265 |
+
sample_scheduler.set_timesteps(
|
| 266 |
+
sampling_steps, device=self.device, shift=shift)
|
| 267 |
+
timesteps = sample_scheduler.timesteps
|
| 268 |
+
elif sample_solver == 'dpm++':
|
| 269 |
+
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
| 270 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 271 |
+
shift=1,
|
| 272 |
+
use_dynamic_shifting=False)
|
| 273 |
+
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
| 274 |
+
timesteps, _ = retrieve_timesteps(
|
| 275 |
+
sample_scheduler,
|
| 276 |
+
device=self.device,
|
| 277 |
+
sigmas=sampling_sigmas)
|
| 278 |
+
else:
|
| 279 |
+
raise NotImplementedError("Unsupported solver.")
|
| 280 |
+
|
| 281 |
+
# sample videos
|
| 282 |
+
latent = noise
|
| 283 |
+
|
| 284 |
+
arg_c = {
|
| 285 |
+
'context': [context[0]],
|
| 286 |
+
'clip_fea': clip_context,
|
| 287 |
+
'seq_len': max_seq_len,
|
| 288 |
+
'y': [y],
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
arg_null = {
|
| 292 |
+
'context': context_null,
|
| 293 |
+
'clip_fea': clip_context,
|
| 294 |
+
'seq_len': max_seq_len,
|
| 295 |
+
'y': [y],
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
if offload_model:
|
| 299 |
+
torch.cuda.empty_cache()
|
| 300 |
+
|
| 301 |
+
self.model.to(self.device)
|
| 302 |
+
for _, t in enumerate(tqdm(timesteps)):
|
| 303 |
+
latent_model_input = [latent.to(self.device)]
|
| 304 |
+
timestep = [t]
|
| 305 |
+
|
| 306 |
+
timestep = torch.stack(timestep).to(self.device)
|
| 307 |
+
|
| 308 |
+
noise_pred_cond = self.model(
|
| 309 |
+
latent_model_input, t=timestep, **arg_c)[0].to(
|
| 310 |
+
torch.device('cpu') if offload_model else self.device)
|
| 311 |
+
if offload_model:
|
| 312 |
+
torch.cuda.empty_cache()
|
| 313 |
+
noise_pred_uncond = self.model(
|
| 314 |
+
latent_model_input, t=timestep, **arg_null)[0].to(
|
| 315 |
+
torch.device('cpu') if offload_model else self.device)
|
| 316 |
+
if offload_model:
|
| 317 |
+
torch.cuda.empty_cache()
|
| 318 |
+
noise_pred = noise_pred_uncond + guide_scale * (
|
| 319 |
+
noise_pred_cond - noise_pred_uncond)
|
| 320 |
+
|
| 321 |
+
latent = latent.to(
|
| 322 |
+
torch.device('cpu') if offload_model else self.device)
|
| 323 |
+
|
| 324 |
+
temp_x0 = sample_scheduler.step(
|
| 325 |
+
noise_pred.unsqueeze(0),
|
| 326 |
+
t,
|
| 327 |
+
latent.unsqueeze(0),
|
| 328 |
+
return_dict=False,
|
| 329 |
+
generator=seed_g)[0]
|
| 330 |
+
latent = temp_x0.squeeze(0)
|
| 331 |
+
|
| 332 |
+
x0 = [latent.to(self.device)]
|
| 333 |
+
del latent_model_input, timestep
|
| 334 |
+
|
| 335 |
+
if offload_model:
|
| 336 |
+
self.model.cpu()
|
| 337 |
+
torch.cuda.empty_cache()
|
| 338 |
+
|
| 339 |
+
if self.rank == 0:
|
| 340 |
+
videos = self.vae.decode(x0)
|
| 341 |
+
|
| 342 |
+
del noise, latent
|
| 343 |
+
del sample_scheduler
|
| 344 |
+
if offload_model:
|
| 345 |
+
gc.collect()
|
| 346 |
+
torch.cuda.synchronize()
|
| 347 |
+
if dist.is_initialized():
|
| 348 |
+
dist.barrier()
|
| 349 |
+
|
| 350 |
+
return videos[0] if self.rank == 0 else None
|
wan/modules/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .attention import flash_attention
|
| 2 |
+
from .model import WanModel
|
| 3 |
+
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
|
| 4 |
+
from .tokenizers import HuggingfaceTokenizer
|
| 5 |
+
from .vace_model import VaceWanModel
|
| 6 |
+
from .vae import WanVAE
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
'WanVAE',
|
| 10 |
+
'WanModel',
|
| 11 |
+
'VaceWanModel',
|
| 12 |
+
'T5Model',
|
| 13 |
+
'T5Encoder',
|
| 14 |
+
'T5Decoder',
|
| 15 |
+
'T5EncoderModel',
|
| 16 |
+
'HuggingfaceTokenizer',
|
| 17 |
+
'flash_attention',
|
| 18 |
+
]
|
wan/modules/attention.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from einops import rearrange, repeat
|
| 5 |
+
from ..utils.multitalk_utils import RotaryPositionalEmbedding1D, normalize_and_scale, split_token_counts_and_frame_ids
|
| 6 |
+
from xfuser.core.distributed import (
|
| 7 |
+
get_sequence_parallel_rank,
|
| 8 |
+
get_sequence_parallel_world_size,
|
| 9 |
+
get_sp_group,
|
| 10 |
+
)
|
| 11 |
+
import xformers.ops
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import flash_attn_interface
|
| 15 |
+
FLASH_ATTN_3_AVAILABLE = True
|
| 16 |
+
except ModuleNotFoundError:
|
| 17 |
+
FLASH_ATTN_3_AVAILABLE = False
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import flash_attn
|
| 21 |
+
FLASH_ATTN_2_AVAILABLE = True
|
| 22 |
+
except ModuleNotFoundError:
|
| 23 |
+
FLASH_ATTN_2_AVAILABLE = False
|
| 24 |
+
|
| 25 |
+
import warnings
|
| 26 |
+
|
| 27 |
+
__all__ = [
|
| 28 |
+
'flash_attention',
|
| 29 |
+
'attention',
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def flash_attention(
|
| 34 |
+
q,
|
| 35 |
+
k,
|
| 36 |
+
v,
|
| 37 |
+
q_lens=None,
|
| 38 |
+
k_lens=None,
|
| 39 |
+
dropout_p=0.,
|
| 40 |
+
softmax_scale=None,
|
| 41 |
+
q_scale=None,
|
| 42 |
+
causal=False,
|
| 43 |
+
window_size=(-1, -1),
|
| 44 |
+
deterministic=False,
|
| 45 |
+
dtype=torch.bfloat16,
|
| 46 |
+
version=None,
|
| 47 |
+
):
|
| 48 |
+
"""
|
| 49 |
+
q: [B, Lq, Nq, C1].
|
| 50 |
+
k: [B, Lk, Nk, C1].
|
| 51 |
+
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
|
| 52 |
+
q_lens: [B].
|
| 53 |
+
k_lens: [B].
|
| 54 |
+
dropout_p: float. Dropout probability.
|
| 55 |
+
softmax_scale: float. The scaling of QK^T before applying softmax.
|
| 56 |
+
causal: bool. Whether to apply causal attention mask.
|
| 57 |
+
window_size: (left right). If not (-1, -1), apply sliding window local attention.
|
| 58 |
+
deterministic: bool. If True, slightly slower and uses more memory.
|
| 59 |
+
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
|
| 60 |
+
"""
|
| 61 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 62 |
+
assert dtype in half_dtypes
|
| 63 |
+
assert q.device.type == 'cuda' and q.size(-1) <= 256
|
| 64 |
+
|
| 65 |
+
# params
|
| 66 |
+
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
|
| 67 |
+
|
| 68 |
+
def half(x):
|
| 69 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 70 |
+
|
| 71 |
+
# preprocess query
|
| 72 |
+
if q_lens is None:
|
| 73 |
+
q = half(q.flatten(0, 1))
|
| 74 |
+
q_lens = torch.tensor(
|
| 75 |
+
[lq] * b, dtype=torch.int32).to(
|
| 76 |
+
device=q.device, non_blocking=True)
|
| 77 |
+
else:
|
| 78 |
+
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
|
| 79 |
+
|
| 80 |
+
# preprocess key, value
|
| 81 |
+
if k_lens is None:
|
| 82 |
+
k = half(k.flatten(0, 1))
|
| 83 |
+
v = half(v.flatten(0, 1))
|
| 84 |
+
k_lens = torch.tensor(
|
| 85 |
+
[lk] * b, dtype=torch.int32).to(
|
| 86 |
+
device=k.device, non_blocking=True)
|
| 87 |
+
else:
|
| 88 |
+
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
|
| 89 |
+
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
|
| 90 |
+
|
| 91 |
+
q = q.to(v.dtype)
|
| 92 |
+
k = k.to(v.dtype)
|
| 93 |
+
|
| 94 |
+
if q_scale is not None:
|
| 95 |
+
q = q * q_scale
|
| 96 |
+
|
| 97 |
+
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
|
| 98 |
+
warnings.warn(
|
| 99 |
+
'Flash attention 3 is not available, use flash attention 2 instead.'
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# apply attention
|
| 103 |
+
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
|
| 104 |
+
# Note: dropout_p, window_size are not supported in FA3 now.
|
| 105 |
+
x = flash_attn_interface.flash_attn_varlen_func(
|
| 106 |
+
q=q,
|
| 107 |
+
k=k,
|
| 108 |
+
v=v,
|
| 109 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
| 110 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 111 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
| 112 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 113 |
+
seqused_q=None,
|
| 114 |
+
seqused_k=None,
|
| 115 |
+
max_seqlen_q=lq,
|
| 116 |
+
max_seqlen_k=lk,
|
| 117 |
+
softmax_scale=softmax_scale,
|
| 118 |
+
causal=causal,
|
| 119 |
+
deterministic=deterministic)[0].unflatten(0, (b, lq))
|
| 120 |
+
else:
|
| 121 |
+
assert FLASH_ATTN_2_AVAILABLE
|
| 122 |
+
x = flash_attn.flash_attn_varlen_func(
|
| 123 |
+
q=q,
|
| 124 |
+
k=k,
|
| 125 |
+
v=v,
|
| 126 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
| 127 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 128 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
| 129 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 130 |
+
max_seqlen_q=lq,
|
| 131 |
+
max_seqlen_k=lk,
|
| 132 |
+
dropout_p=dropout_p,
|
| 133 |
+
softmax_scale=softmax_scale,
|
| 134 |
+
causal=causal,
|
| 135 |
+
window_size=window_size,
|
| 136 |
+
deterministic=deterministic).unflatten(0, (b, lq))
|
| 137 |
+
|
| 138 |
+
# output
|
| 139 |
+
return x.type(out_dtype)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def attention(
|
| 143 |
+
q,
|
| 144 |
+
k,
|
| 145 |
+
v,
|
| 146 |
+
q_lens=None,
|
| 147 |
+
k_lens=None,
|
| 148 |
+
dropout_p=0.,
|
| 149 |
+
softmax_scale=None,
|
| 150 |
+
q_scale=None,
|
| 151 |
+
causal=False,
|
| 152 |
+
window_size=(-1, -1),
|
| 153 |
+
deterministic=False,
|
| 154 |
+
dtype=torch.bfloat16,
|
| 155 |
+
fa_version=None,
|
| 156 |
+
):
|
| 157 |
+
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
|
| 158 |
+
return flash_attention(
|
| 159 |
+
q=q,
|
| 160 |
+
k=k,
|
| 161 |
+
v=v,
|
| 162 |
+
q_lens=q_lens,
|
| 163 |
+
k_lens=k_lens,
|
| 164 |
+
dropout_p=dropout_p,
|
| 165 |
+
softmax_scale=softmax_scale,
|
| 166 |
+
q_scale=q_scale,
|
| 167 |
+
causal=causal,
|
| 168 |
+
window_size=window_size,
|
| 169 |
+
deterministic=deterministic,
|
| 170 |
+
dtype=dtype,
|
| 171 |
+
version=fa_version,
|
| 172 |
+
)
|
| 173 |
+
else:
|
| 174 |
+
if q_lens is not None or k_lens is not None:
|
| 175 |
+
warnings.warn(
|
| 176 |
+
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
|
| 177 |
+
)
|
| 178 |
+
attn_mask = None
|
| 179 |
+
|
| 180 |
+
q = q.transpose(1, 2).to(dtype)
|
| 181 |
+
k = k.transpose(1, 2).to(dtype)
|
| 182 |
+
v = v.transpose(1, 2).to(dtype)
|
| 183 |
+
|
| 184 |
+
out = torch.nn.functional.scaled_dot_product_attention(
|
| 185 |
+
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
|
| 186 |
+
|
| 187 |
+
out = out.transpose(1, 2).contiguous()
|
| 188 |
+
return out
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class SingleStreamAttention(nn.Module):
|
| 192 |
+
def __init__(
|
| 193 |
+
self,
|
| 194 |
+
dim: int,
|
| 195 |
+
encoder_hidden_states_dim: int,
|
| 196 |
+
num_heads: int,
|
| 197 |
+
qkv_bias: bool,
|
| 198 |
+
qk_norm: bool,
|
| 199 |
+
norm_layer: nn.Module,
|
| 200 |
+
attn_drop: float = 0.0,
|
| 201 |
+
proj_drop: float = 0.0,
|
| 202 |
+
eps: float = 1e-6,
|
| 203 |
+
) -> None:
|
| 204 |
+
super().__init__()
|
| 205 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
| 206 |
+
self.dim = dim
|
| 207 |
+
self.encoder_hidden_states_dim = encoder_hidden_states_dim
|
| 208 |
+
self.num_heads = num_heads
|
| 209 |
+
self.head_dim = dim // num_heads
|
| 210 |
+
self.scale = self.head_dim**-0.5
|
| 211 |
+
self.qk_norm = qk_norm
|
| 212 |
+
|
| 213 |
+
self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)
|
| 214 |
+
|
| 215 |
+
self.q_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 216 |
+
self.k_norm = norm_layer(self.head_dim,eps=eps) if qk_norm else nn.Identity()
|
| 217 |
+
|
| 218 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 219 |
+
self.proj = nn.Linear(dim, dim)
|
| 220 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 221 |
+
|
| 222 |
+
self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias)
|
| 223 |
+
|
| 224 |
+
self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 225 |
+
self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 226 |
+
|
| 227 |
+
def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
|
| 228 |
+
|
| 229 |
+
N_t, N_h, N_w = shape
|
| 230 |
+
if not enable_sp:
|
| 231 |
+
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
|
| 232 |
+
|
| 233 |
+
# get q for hidden_state
|
| 234 |
+
B, N, C = x.shape
|
| 235 |
+
q = self.q_linear(x)
|
| 236 |
+
q_shape = (B, N, self.num_heads, self.head_dim)
|
| 237 |
+
q = q.view(q_shape).permute((0, 2, 1, 3))
|
| 238 |
+
|
| 239 |
+
if self.qk_norm:
|
| 240 |
+
q = self.q_norm(q)
|
| 241 |
+
|
| 242 |
+
# get kv from encoder_hidden_states
|
| 243 |
+
_, N_a, _ = encoder_hidden_states.shape
|
| 244 |
+
encoder_kv = self.kv_linear(encoder_hidden_states)
|
| 245 |
+
encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
|
| 246 |
+
encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
|
| 247 |
+
encoder_k, encoder_v = encoder_kv.unbind(0)
|
| 248 |
+
|
| 249 |
+
if self.qk_norm:
|
| 250 |
+
encoder_k = self.add_k_norm(encoder_k)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
q = rearrange(q, "B H M K -> B M H K")
|
| 254 |
+
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
|
| 255 |
+
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
|
| 256 |
+
|
| 257 |
+
if enable_sp:
|
| 258 |
+
# context parallel
|
| 259 |
+
sp_size = get_sequence_parallel_world_size()
|
| 260 |
+
sp_rank = get_sequence_parallel_rank()
|
| 261 |
+
visual_seqlen, _ = split_token_counts_and_frame_ids(N_t, N_h * N_w, sp_size, sp_rank)
|
| 262 |
+
assert kv_seq is not None, f"kv_seq should not be None."
|
| 263 |
+
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(visual_seqlen, kv_seq)
|
| 264 |
+
else:
|
| 265 |
+
attn_bias = None
|
| 266 |
+
x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
|
| 267 |
+
x = rearrange(x, "B M H K -> B H M K")
|
| 268 |
+
|
| 269 |
+
# linear transform
|
| 270 |
+
x_output_shape = (B, N, C)
|
| 271 |
+
x = x.transpose(1, 2)
|
| 272 |
+
x = x.reshape(x_output_shape)
|
| 273 |
+
x = self.proj(x)
|
| 274 |
+
x = self.proj_drop(x)
|
| 275 |
+
|
| 276 |
+
if not enable_sp:
|
| 277 |
+
# reshape x to origin shape
|
| 278 |
+
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
|
| 279 |
+
|
| 280 |
+
return x
|
| 281 |
+
|
| 282 |
+
class SingleStreamMutiAttention(SingleStreamAttention):
|
| 283 |
+
def __init__(
|
| 284 |
+
self,
|
| 285 |
+
dim: int,
|
| 286 |
+
encoder_hidden_states_dim: int,
|
| 287 |
+
num_heads: int,
|
| 288 |
+
qkv_bias: bool,
|
| 289 |
+
qk_norm: bool,
|
| 290 |
+
norm_layer: nn.Module,
|
| 291 |
+
attn_drop: float = 0.0,
|
| 292 |
+
proj_drop: float = 0.0,
|
| 293 |
+
eps: float = 1e-6,
|
| 294 |
+
class_range: int = 24,
|
| 295 |
+
class_interval: int = 4,
|
| 296 |
+
) -> None:
|
| 297 |
+
super().__init__(
|
| 298 |
+
dim=dim,
|
| 299 |
+
encoder_hidden_states_dim=encoder_hidden_states_dim,
|
| 300 |
+
num_heads=num_heads,
|
| 301 |
+
qkv_bias=qkv_bias,
|
| 302 |
+
qk_norm=qk_norm,
|
| 303 |
+
norm_layer=norm_layer,
|
| 304 |
+
attn_drop=attn_drop,
|
| 305 |
+
proj_drop=proj_drop,
|
| 306 |
+
eps=eps,
|
| 307 |
+
)
|
| 308 |
+
self.class_interval = class_interval
|
| 309 |
+
self.class_range = class_range
|
| 310 |
+
self.rope_h1 = (0, self.class_interval)
|
| 311 |
+
self.rope_h2 = (self.class_range - self.class_interval, self.class_range)
|
| 312 |
+
self.rope_bak = int(self.class_range // 2)
|
| 313 |
+
|
| 314 |
+
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
|
| 315 |
+
|
| 316 |
+
def forward(self,
|
| 317 |
+
x: torch.Tensor,
|
| 318 |
+
encoder_hidden_states: torch.Tensor,
|
| 319 |
+
shape=None,
|
| 320 |
+
x_ref_attn_map=None,
|
| 321 |
+
human_num=None) -> torch.Tensor:
|
| 322 |
+
|
| 323 |
+
encoder_hidden_states = encoder_hidden_states.squeeze(0)
|
| 324 |
+
if human_num == 1:
|
| 325 |
+
return super().forward(x, encoder_hidden_states, shape)
|
| 326 |
+
|
| 327 |
+
N_t, _, _ = shape
|
| 328 |
+
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
|
| 329 |
+
|
| 330 |
+
# get q for hidden_state
|
| 331 |
+
B, N, C = x.shape
|
| 332 |
+
q = self.q_linear(x)
|
| 333 |
+
q_shape = (B, N, self.num_heads, self.head_dim)
|
| 334 |
+
q = q.view(q_shape).permute((0, 2, 1, 3))
|
| 335 |
+
|
| 336 |
+
if self.qk_norm:
|
| 337 |
+
q = self.q_norm(q)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
max_values = x_ref_attn_map.max(1).values[:, None, None]
|
| 341 |
+
min_values = x_ref_attn_map.min(1).values[:, None, None]
|
| 342 |
+
max_min_values = torch.cat([max_values, min_values], dim=2)
|
| 343 |
+
|
| 344 |
+
human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
|
| 345 |
+
human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
|
| 346 |
+
|
| 347 |
+
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
|
| 348 |
+
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
|
| 349 |
+
back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device)
|
| 350 |
+
max_indices = x_ref_attn_map.argmax(dim=0)
|
| 351 |
+
normalized_map = torch.stack([human1, human2, back], dim=1)
|
| 352 |
+
normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N
|
| 353 |
+
|
| 354 |
+
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
| 355 |
+
q = self.rope_1d(q, normalized_pos)
|
| 356 |
+
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
| 357 |
+
|
| 358 |
+
_, N_a, _ = encoder_hidden_states.shape
|
| 359 |
+
encoder_kv = self.kv_linear(encoder_hidden_states)
|
| 360 |
+
encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
|
| 361 |
+
encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
|
| 362 |
+
encoder_k, encoder_v = encoder_kv.unbind(0)
|
| 363 |
+
|
| 364 |
+
if self.qk_norm:
|
| 365 |
+
encoder_k = self.add_k_norm(encoder_k)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device)
|
| 369 |
+
per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2
|
| 370 |
+
per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
|
| 371 |
+
encoder_pos = torch.concat([per_frame]*N_t, dim=0)
|
| 372 |
+
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
| 373 |
+
encoder_k = self.rope_1d(encoder_k, encoder_pos)
|
| 374 |
+
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
q = rearrange(q, "B H M K -> B M H K")
|
| 378 |
+
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
|
| 379 |
+
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
|
| 380 |
+
x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,)
|
| 381 |
+
x = rearrange(x, "B M H K -> B H M K")
|
| 382 |
+
|
| 383 |
+
# linear transform
|
| 384 |
+
x_output_shape = (B, N, C)
|
| 385 |
+
x = x.transpose(1, 2)
|
| 386 |
+
x = x.reshape(x_output_shape)
|
| 387 |
+
x = self.proj(x)
|
| 388 |
+
x = self.proj_drop(x)
|
| 389 |
+
|
| 390 |
+
# reshape x to origin shape
|
| 391 |
+
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
|
| 392 |
+
|
| 393 |
+
return x
|
wan/modules/clip.py
ADDED
|
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.transforms as T
|
| 10 |
+
|
| 11 |
+
from .attention import flash_attention
|
| 12 |
+
from .tokenizers import HuggingfaceTokenizer
|
| 13 |
+
from .xlm_roberta import XLMRoberta
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
'XLMRobertaCLIP',
|
| 17 |
+
'clip_xlm_roberta_vit_h_14',
|
| 18 |
+
'CLIPModel',
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def pos_interpolate(pos, seq_len):
|
| 23 |
+
if pos.size(1) == seq_len:
|
| 24 |
+
return pos
|
| 25 |
+
else:
|
| 26 |
+
src_grid = int(math.sqrt(pos.size(1)))
|
| 27 |
+
tar_grid = int(math.sqrt(seq_len))
|
| 28 |
+
n = pos.size(1) - src_grid * src_grid
|
| 29 |
+
return torch.cat([
|
| 30 |
+
pos[:, :n],
|
| 31 |
+
F.interpolate(
|
| 32 |
+
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
|
| 33 |
+
0, 3, 1, 2),
|
| 34 |
+
size=(tar_grid, tar_grid),
|
| 35 |
+
mode='bicubic',
|
| 36 |
+
align_corners=False).flatten(2).transpose(1, 2)
|
| 37 |
+
],
|
| 38 |
+
dim=1)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class QuickGELU(nn.Module):
|
| 42 |
+
|
| 43 |
+
def forward(self, x):
|
| 44 |
+
return x * torch.sigmoid(1.702 * x)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class LayerNorm(nn.LayerNorm):
|
| 48 |
+
|
| 49 |
+
def forward(self, x):
|
| 50 |
+
return super().forward(x.float()).type_as(x)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class SelfAttention(nn.Module):
|
| 54 |
+
|
| 55 |
+
def __init__(self,
|
| 56 |
+
dim,
|
| 57 |
+
num_heads,
|
| 58 |
+
causal=False,
|
| 59 |
+
attn_dropout=0.0,
|
| 60 |
+
proj_dropout=0.0):
|
| 61 |
+
assert dim % num_heads == 0
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.dim = dim
|
| 64 |
+
self.num_heads = num_heads
|
| 65 |
+
self.head_dim = dim // num_heads
|
| 66 |
+
self.causal = causal
|
| 67 |
+
self.attn_dropout = attn_dropout
|
| 68 |
+
self.proj_dropout = proj_dropout
|
| 69 |
+
|
| 70 |
+
# layers
|
| 71 |
+
self.to_qkv = nn.Linear(dim, dim * 3)
|
| 72 |
+
self.proj = nn.Linear(dim, dim)
|
| 73 |
+
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
"""
|
| 76 |
+
x: [B, L, C].
|
| 77 |
+
"""
|
| 78 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
| 79 |
+
|
| 80 |
+
# compute query, key, value
|
| 81 |
+
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
|
| 82 |
+
|
| 83 |
+
# compute attention
|
| 84 |
+
p = self.attn_dropout if self.training else 0.0
|
| 85 |
+
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
|
| 86 |
+
x = x.reshape(b, s, c)
|
| 87 |
+
|
| 88 |
+
# output
|
| 89 |
+
x = self.proj(x)
|
| 90 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
| 91 |
+
return x
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class SwiGLU(nn.Module):
|
| 95 |
+
|
| 96 |
+
def __init__(self, dim, mid_dim):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.dim = dim
|
| 99 |
+
self.mid_dim = mid_dim
|
| 100 |
+
|
| 101 |
+
# layers
|
| 102 |
+
self.fc1 = nn.Linear(dim, mid_dim)
|
| 103 |
+
self.fc2 = nn.Linear(dim, mid_dim)
|
| 104 |
+
self.fc3 = nn.Linear(mid_dim, dim)
|
| 105 |
+
|
| 106 |
+
def forward(self, x):
|
| 107 |
+
x = F.silu(self.fc1(x)) * self.fc2(x)
|
| 108 |
+
x = self.fc3(x)
|
| 109 |
+
return x
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class AttentionBlock(nn.Module):
|
| 113 |
+
|
| 114 |
+
def __init__(self,
|
| 115 |
+
dim,
|
| 116 |
+
mlp_ratio,
|
| 117 |
+
num_heads,
|
| 118 |
+
post_norm=False,
|
| 119 |
+
causal=False,
|
| 120 |
+
activation='quick_gelu',
|
| 121 |
+
attn_dropout=0.0,
|
| 122 |
+
proj_dropout=0.0,
|
| 123 |
+
norm_eps=1e-5):
|
| 124 |
+
assert activation in ['quick_gelu', 'gelu', 'swi_glu']
|
| 125 |
+
super().__init__()
|
| 126 |
+
self.dim = dim
|
| 127 |
+
self.mlp_ratio = mlp_ratio
|
| 128 |
+
self.num_heads = num_heads
|
| 129 |
+
self.post_norm = post_norm
|
| 130 |
+
self.causal = causal
|
| 131 |
+
self.norm_eps = norm_eps
|
| 132 |
+
|
| 133 |
+
# layers
|
| 134 |
+
self.norm1 = LayerNorm(dim, eps=norm_eps)
|
| 135 |
+
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
|
| 136 |
+
proj_dropout)
|
| 137 |
+
self.norm2 = LayerNorm(dim, eps=norm_eps)
|
| 138 |
+
if activation == 'swi_glu':
|
| 139 |
+
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
|
| 140 |
+
else:
|
| 141 |
+
self.mlp = nn.Sequential(
|
| 142 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
| 143 |
+
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
|
| 144 |
+
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
|
| 145 |
+
|
| 146 |
+
def forward(self, x):
|
| 147 |
+
if self.post_norm:
|
| 148 |
+
x = x + self.norm1(self.attn(x))
|
| 149 |
+
x = x + self.norm2(self.mlp(x))
|
| 150 |
+
else:
|
| 151 |
+
x = x + self.attn(self.norm1(x))
|
| 152 |
+
x = x + self.mlp(self.norm2(x))
|
| 153 |
+
return x
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class AttentionPool(nn.Module):
|
| 157 |
+
|
| 158 |
+
def __init__(self,
|
| 159 |
+
dim,
|
| 160 |
+
mlp_ratio,
|
| 161 |
+
num_heads,
|
| 162 |
+
activation='gelu',
|
| 163 |
+
proj_dropout=0.0,
|
| 164 |
+
norm_eps=1e-5):
|
| 165 |
+
assert dim % num_heads == 0
|
| 166 |
+
super().__init__()
|
| 167 |
+
self.dim = dim
|
| 168 |
+
self.mlp_ratio = mlp_ratio
|
| 169 |
+
self.num_heads = num_heads
|
| 170 |
+
self.head_dim = dim // num_heads
|
| 171 |
+
self.proj_dropout = proj_dropout
|
| 172 |
+
self.norm_eps = norm_eps
|
| 173 |
+
|
| 174 |
+
# layers
|
| 175 |
+
gain = 1.0 / math.sqrt(dim)
|
| 176 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
| 177 |
+
self.to_q = nn.Linear(dim, dim)
|
| 178 |
+
self.to_kv = nn.Linear(dim, dim * 2)
|
| 179 |
+
self.proj = nn.Linear(dim, dim)
|
| 180 |
+
self.norm = LayerNorm(dim, eps=norm_eps)
|
| 181 |
+
self.mlp = nn.Sequential(
|
| 182 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
| 183 |
+
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
|
| 184 |
+
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
|
| 185 |
+
|
| 186 |
+
def forward(self, x):
|
| 187 |
+
"""
|
| 188 |
+
x: [B, L, C].
|
| 189 |
+
"""
|
| 190 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
| 191 |
+
|
| 192 |
+
# compute query, key, value
|
| 193 |
+
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
|
| 194 |
+
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
| 195 |
+
|
| 196 |
+
# compute attention
|
| 197 |
+
x = flash_attention(q, k, v, version=2)
|
| 198 |
+
x = x.reshape(b, 1, c)
|
| 199 |
+
|
| 200 |
+
# output
|
| 201 |
+
x = self.proj(x)
|
| 202 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
| 203 |
+
|
| 204 |
+
# mlp
|
| 205 |
+
x = x + self.mlp(self.norm(x))
|
| 206 |
+
return x[:, 0]
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class VisionTransformer(nn.Module):
|
| 210 |
+
|
| 211 |
+
def __init__(self,
|
| 212 |
+
image_size=224,
|
| 213 |
+
patch_size=16,
|
| 214 |
+
dim=768,
|
| 215 |
+
mlp_ratio=4,
|
| 216 |
+
out_dim=512,
|
| 217 |
+
num_heads=12,
|
| 218 |
+
num_layers=12,
|
| 219 |
+
pool_type='token',
|
| 220 |
+
pre_norm=True,
|
| 221 |
+
post_norm=False,
|
| 222 |
+
activation='quick_gelu',
|
| 223 |
+
attn_dropout=0.0,
|
| 224 |
+
proj_dropout=0.0,
|
| 225 |
+
embedding_dropout=0.0,
|
| 226 |
+
norm_eps=1e-5):
|
| 227 |
+
if image_size % patch_size != 0:
|
| 228 |
+
print(
|
| 229 |
+
'[WARNING] image_size is not divisible by patch_size',
|
| 230 |
+
flush=True)
|
| 231 |
+
assert pool_type in ('token', 'token_fc', 'attn_pool')
|
| 232 |
+
out_dim = out_dim or dim
|
| 233 |
+
super().__init__()
|
| 234 |
+
self.image_size = image_size
|
| 235 |
+
self.patch_size = patch_size
|
| 236 |
+
self.num_patches = (image_size // patch_size)**2
|
| 237 |
+
self.dim = dim
|
| 238 |
+
self.mlp_ratio = mlp_ratio
|
| 239 |
+
self.out_dim = out_dim
|
| 240 |
+
self.num_heads = num_heads
|
| 241 |
+
self.num_layers = num_layers
|
| 242 |
+
self.pool_type = pool_type
|
| 243 |
+
self.post_norm = post_norm
|
| 244 |
+
self.norm_eps = norm_eps
|
| 245 |
+
|
| 246 |
+
# embeddings
|
| 247 |
+
gain = 1.0 / math.sqrt(dim)
|
| 248 |
+
self.patch_embedding = nn.Conv2d(
|
| 249 |
+
3,
|
| 250 |
+
dim,
|
| 251 |
+
kernel_size=patch_size,
|
| 252 |
+
stride=patch_size,
|
| 253 |
+
bias=not pre_norm)
|
| 254 |
+
if pool_type in ('token', 'token_fc'):
|
| 255 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
| 256 |
+
self.pos_embedding = nn.Parameter(gain * torch.randn(
|
| 257 |
+
1, self.num_patches +
|
| 258 |
+
(1 if pool_type in ('token', 'token_fc') else 0), dim))
|
| 259 |
+
self.dropout = nn.Dropout(embedding_dropout)
|
| 260 |
+
|
| 261 |
+
# transformer
|
| 262 |
+
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
|
| 263 |
+
self.transformer = nn.Sequential(*[
|
| 264 |
+
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
|
| 265 |
+
activation, attn_dropout, proj_dropout, norm_eps)
|
| 266 |
+
for _ in range(num_layers)
|
| 267 |
+
])
|
| 268 |
+
self.post_norm = LayerNorm(dim, eps=norm_eps)
|
| 269 |
+
|
| 270 |
+
# head
|
| 271 |
+
if pool_type == 'token':
|
| 272 |
+
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
| 273 |
+
elif pool_type == 'token_fc':
|
| 274 |
+
self.head = nn.Linear(dim, out_dim)
|
| 275 |
+
elif pool_type == 'attn_pool':
|
| 276 |
+
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
|
| 277 |
+
proj_dropout, norm_eps)
|
| 278 |
+
|
| 279 |
+
def forward(self, x, interpolation=False, use_31_block=False):
|
| 280 |
+
b = x.size(0)
|
| 281 |
+
|
| 282 |
+
# embeddings
|
| 283 |
+
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
|
| 284 |
+
if self.pool_type in ('token', 'token_fc'):
|
| 285 |
+
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
|
| 286 |
+
if interpolation:
|
| 287 |
+
e = pos_interpolate(self.pos_embedding, x.size(1))
|
| 288 |
+
else:
|
| 289 |
+
e = self.pos_embedding
|
| 290 |
+
x = self.dropout(x + e)
|
| 291 |
+
if self.pre_norm is not None:
|
| 292 |
+
x = self.pre_norm(x)
|
| 293 |
+
|
| 294 |
+
# transformer
|
| 295 |
+
if use_31_block:
|
| 296 |
+
x = self.transformer[:-1](x)
|
| 297 |
+
return x
|
| 298 |
+
else:
|
| 299 |
+
x = self.transformer(x)
|
| 300 |
+
return x
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class XLMRobertaWithHead(XLMRoberta):
|
| 304 |
+
|
| 305 |
+
def __init__(self, **kwargs):
|
| 306 |
+
self.out_dim = kwargs.pop('out_dim')
|
| 307 |
+
super().__init__(**kwargs)
|
| 308 |
+
|
| 309 |
+
# head
|
| 310 |
+
mid_dim = (self.dim + self.out_dim) // 2
|
| 311 |
+
self.head = nn.Sequential(
|
| 312 |
+
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
|
| 313 |
+
nn.Linear(mid_dim, self.out_dim, bias=False))
|
| 314 |
+
|
| 315 |
+
def forward(self, ids):
|
| 316 |
+
# xlm-roberta
|
| 317 |
+
x = super().forward(ids)
|
| 318 |
+
|
| 319 |
+
# average pooling
|
| 320 |
+
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
|
| 321 |
+
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
|
| 322 |
+
|
| 323 |
+
# head
|
| 324 |
+
x = self.head(x)
|
| 325 |
+
return x
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class XLMRobertaCLIP(nn.Module):
|
| 329 |
+
|
| 330 |
+
def __init__(self,
|
| 331 |
+
embed_dim=1024,
|
| 332 |
+
image_size=224,
|
| 333 |
+
patch_size=14,
|
| 334 |
+
vision_dim=1280,
|
| 335 |
+
vision_mlp_ratio=4,
|
| 336 |
+
vision_heads=16,
|
| 337 |
+
vision_layers=32,
|
| 338 |
+
vision_pool='token',
|
| 339 |
+
vision_pre_norm=True,
|
| 340 |
+
vision_post_norm=False,
|
| 341 |
+
activation='gelu',
|
| 342 |
+
vocab_size=250002,
|
| 343 |
+
max_text_len=514,
|
| 344 |
+
type_size=1,
|
| 345 |
+
pad_id=1,
|
| 346 |
+
text_dim=1024,
|
| 347 |
+
text_heads=16,
|
| 348 |
+
text_layers=24,
|
| 349 |
+
text_post_norm=True,
|
| 350 |
+
text_dropout=0.1,
|
| 351 |
+
attn_dropout=0.0,
|
| 352 |
+
proj_dropout=0.0,
|
| 353 |
+
embedding_dropout=0.0,
|
| 354 |
+
norm_eps=1e-5):
|
| 355 |
+
super().__init__()
|
| 356 |
+
self.embed_dim = embed_dim
|
| 357 |
+
self.image_size = image_size
|
| 358 |
+
self.patch_size = patch_size
|
| 359 |
+
self.vision_dim = vision_dim
|
| 360 |
+
self.vision_mlp_ratio = vision_mlp_ratio
|
| 361 |
+
self.vision_heads = vision_heads
|
| 362 |
+
self.vision_layers = vision_layers
|
| 363 |
+
self.vision_pre_norm = vision_pre_norm
|
| 364 |
+
self.vision_post_norm = vision_post_norm
|
| 365 |
+
self.activation = activation
|
| 366 |
+
self.vocab_size = vocab_size
|
| 367 |
+
self.max_text_len = max_text_len
|
| 368 |
+
self.type_size = type_size
|
| 369 |
+
self.pad_id = pad_id
|
| 370 |
+
self.text_dim = text_dim
|
| 371 |
+
self.text_heads = text_heads
|
| 372 |
+
self.text_layers = text_layers
|
| 373 |
+
self.text_post_norm = text_post_norm
|
| 374 |
+
self.norm_eps = norm_eps
|
| 375 |
+
|
| 376 |
+
# models
|
| 377 |
+
self.visual = VisionTransformer(
|
| 378 |
+
image_size=image_size,
|
| 379 |
+
patch_size=patch_size,
|
| 380 |
+
dim=vision_dim,
|
| 381 |
+
mlp_ratio=vision_mlp_ratio,
|
| 382 |
+
out_dim=embed_dim,
|
| 383 |
+
num_heads=vision_heads,
|
| 384 |
+
num_layers=vision_layers,
|
| 385 |
+
pool_type=vision_pool,
|
| 386 |
+
pre_norm=vision_pre_norm,
|
| 387 |
+
post_norm=vision_post_norm,
|
| 388 |
+
activation=activation,
|
| 389 |
+
attn_dropout=attn_dropout,
|
| 390 |
+
proj_dropout=proj_dropout,
|
| 391 |
+
embedding_dropout=embedding_dropout,
|
| 392 |
+
norm_eps=norm_eps)
|
| 393 |
+
self.textual = XLMRobertaWithHead(
|
| 394 |
+
vocab_size=vocab_size,
|
| 395 |
+
max_seq_len=max_text_len,
|
| 396 |
+
type_size=type_size,
|
| 397 |
+
pad_id=pad_id,
|
| 398 |
+
dim=text_dim,
|
| 399 |
+
out_dim=embed_dim,
|
| 400 |
+
num_heads=text_heads,
|
| 401 |
+
num_layers=text_layers,
|
| 402 |
+
post_norm=text_post_norm,
|
| 403 |
+
dropout=text_dropout)
|
| 404 |
+
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
| 405 |
+
|
| 406 |
+
def forward(self, imgs, txt_ids):
|
| 407 |
+
"""
|
| 408 |
+
imgs: [B, 3, H, W] of torch.float32.
|
| 409 |
+
- mean: [0.48145466, 0.4578275, 0.40821073]
|
| 410 |
+
- std: [0.26862954, 0.26130258, 0.27577711]
|
| 411 |
+
txt_ids: [B, L] of torch.long.
|
| 412 |
+
Encoded by data.CLIPTokenizer.
|
| 413 |
+
"""
|
| 414 |
+
xi = self.visual(imgs)
|
| 415 |
+
xt = self.textual(txt_ids)
|
| 416 |
+
return xi, xt
|
| 417 |
+
|
| 418 |
+
def param_groups(self):
|
| 419 |
+
groups = [{
|
| 420 |
+
'params': [
|
| 421 |
+
p for n, p in self.named_parameters()
|
| 422 |
+
if 'norm' in n or n.endswith('bias')
|
| 423 |
+
],
|
| 424 |
+
'weight_decay': 0.0
|
| 425 |
+
}, {
|
| 426 |
+
'params': [
|
| 427 |
+
p for n, p in self.named_parameters()
|
| 428 |
+
if not ('norm' in n or n.endswith('bias'))
|
| 429 |
+
]
|
| 430 |
+
}]
|
| 431 |
+
return groups
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def _clip(pretrained=False,
|
| 435 |
+
pretrained_name=None,
|
| 436 |
+
model_cls=XLMRobertaCLIP,
|
| 437 |
+
return_transforms=False,
|
| 438 |
+
return_tokenizer=False,
|
| 439 |
+
tokenizer_padding='eos',
|
| 440 |
+
dtype=torch.float32,
|
| 441 |
+
device='cpu',
|
| 442 |
+
**kwargs):
|
| 443 |
+
# init a model on device
|
| 444 |
+
with torch.device(device):
|
| 445 |
+
model = model_cls(**kwargs)
|
| 446 |
+
|
| 447 |
+
# set device
|
| 448 |
+
model = model.to(dtype=dtype, device=device)
|
| 449 |
+
output = (model,)
|
| 450 |
+
|
| 451 |
+
# init transforms
|
| 452 |
+
if return_transforms:
|
| 453 |
+
# mean and std
|
| 454 |
+
if 'siglip' in pretrained_name.lower():
|
| 455 |
+
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
|
| 456 |
+
else:
|
| 457 |
+
mean = [0.48145466, 0.4578275, 0.40821073]
|
| 458 |
+
std = [0.26862954, 0.26130258, 0.27577711]
|
| 459 |
+
|
| 460 |
+
# transforms
|
| 461 |
+
transforms = T.Compose([
|
| 462 |
+
T.Resize((model.image_size, model.image_size),
|
| 463 |
+
interpolation=T.InterpolationMode.BICUBIC),
|
| 464 |
+
T.ToTensor(),
|
| 465 |
+
T.Normalize(mean=mean, std=std)
|
| 466 |
+
])
|
| 467 |
+
output += (transforms,)
|
| 468 |
+
return output[0] if len(output) == 1 else output
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def clip_xlm_roberta_vit_h_14(
|
| 472 |
+
pretrained=False,
|
| 473 |
+
pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
|
| 474 |
+
**kwargs):
|
| 475 |
+
cfg = dict(
|
| 476 |
+
embed_dim=1024,
|
| 477 |
+
image_size=224,
|
| 478 |
+
patch_size=14,
|
| 479 |
+
vision_dim=1280,
|
| 480 |
+
vision_mlp_ratio=4,
|
| 481 |
+
vision_heads=16,
|
| 482 |
+
vision_layers=32,
|
| 483 |
+
vision_pool='token',
|
| 484 |
+
activation='gelu',
|
| 485 |
+
vocab_size=250002,
|
| 486 |
+
max_text_len=514,
|
| 487 |
+
type_size=1,
|
| 488 |
+
pad_id=1,
|
| 489 |
+
text_dim=1024,
|
| 490 |
+
text_heads=16,
|
| 491 |
+
text_layers=24,
|
| 492 |
+
text_post_norm=True,
|
| 493 |
+
text_dropout=0.1,
|
| 494 |
+
attn_dropout=0.0,
|
| 495 |
+
proj_dropout=0.0,
|
| 496 |
+
embedding_dropout=0.0)
|
| 497 |
+
cfg.update(**kwargs)
|
| 498 |
+
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class CLIPModel:
|
| 502 |
+
|
| 503 |
+
def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
|
| 504 |
+
self.dtype = dtype
|
| 505 |
+
self.device = device
|
| 506 |
+
self.checkpoint_path = checkpoint_path
|
| 507 |
+
self.tokenizer_path = tokenizer_path
|
| 508 |
+
|
| 509 |
+
# init model
|
| 510 |
+
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
|
| 511 |
+
pretrained=False,
|
| 512 |
+
return_transforms=True,
|
| 513 |
+
return_tokenizer=False,
|
| 514 |
+
dtype=dtype,
|
| 515 |
+
device=device)
|
| 516 |
+
self.model = self.model.eval().requires_grad_(False)
|
| 517 |
+
logging.info(f'loading {checkpoint_path}')
|
| 518 |
+
self.model.load_state_dict(
|
| 519 |
+
torch.load(checkpoint_path, map_location='cpu'))
|
| 520 |
+
|
| 521 |
+
# init tokenizer
|
| 522 |
+
self.tokenizer = HuggingfaceTokenizer(
|
| 523 |
+
name=tokenizer_path,
|
| 524 |
+
seq_len=self.model.max_text_len - 2,
|
| 525 |
+
clean='whitespace')
|
| 526 |
+
|
| 527 |
+
def visual(self, videos):
|
| 528 |
+
# preprocess
|
| 529 |
+
size = (self.model.image_size,) * 2
|
| 530 |
+
videos = torch.cat([
|
| 531 |
+
F.interpolate(
|
| 532 |
+
u.transpose(0, 1),
|
| 533 |
+
size=size,
|
| 534 |
+
mode='bicubic',
|
| 535 |
+
align_corners=False) for u in videos
|
| 536 |
+
])
|
| 537 |
+
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
| 538 |
+
|
| 539 |
+
# forward
|
| 540 |
+
with torch.cuda.amp.autocast(dtype=self.dtype):
|
| 541 |
+
out = self.model.visual(videos, use_31_block=True)
|
| 542 |
+
return out
|
wan/modules/model.py
ADDED
|
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.cuda.amp as amp
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 8 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 9 |
+
|
| 10 |
+
from .attention import flash_attention
|
| 11 |
+
|
| 12 |
+
__all__ = ['WanModel']
|
| 13 |
+
|
| 14 |
+
T5_CONTEXT_TOKEN_NUMBER = 512
|
| 15 |
+
FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def sinusoidal_embedding_1d(dim, position):
|
| 19 |
+
# preprocess
|
| 20 |
+
assert dim % 2 == 0
|
| 21 |
+
half = dim // 2
|
| 22 |
+
position = position.type(torch.float64)
|
| 23 |
+
|
| 24 |
+
# calculation
|
| 25 |
+
sinusoid = torch.outer(
|
| 26 |
+
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
| 27 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
| 28 |
+
return x
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@amp.autocast(enabled=False)
|
| 32 |
+
def rope_params(max_seq_len, dim, theta=10000):
|
| 33 |
+
assert dim % 2 == 0
|
| 34 |
+
freqs = torch.outer(
|
| 35 |
+
torch.arange(max_seq_len),
|
| 36 |
+
1.0 / torch.pow(theta,
|
| 37 |
+
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
| 38 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
| 39 |
+
return freqs
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@amp.autocast(enabled=False)
|
| 43 |
+
def rope_apply(x, grid_sizes, freqs):
|
| 44 |
+
n, c = x.size(2), x.size(3) // 2
|
| 45 |
+
|
| 46 |
+
# split freqs
|
| 47 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
| 48 |
+
|
| 49 |
+
# loop over samples
|
| 50 |
+
output = []
|
| 51 |
+
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
| 52 |
+
seq_len = f * h * w
|
| 53 |
+
|
| 54 |
+
# precompute multipliers
|
| 55 |
+
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
|
| 56 |
+
seq_len, n, -1, 2))
|
| 57 |
+
freqs_i = torch.cat([
|
| 58 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 59 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 60 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 61 |
+
],
|
| 62 |
+
dim=-1).reshape(seq_len, 1, -1)
|
| 63 |
+
|
| 64 |
+
# apply rotary embedding
|
| 65 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
| 66 |
+
x_i = torch.cat([x_i, x[i, seq_len:]])
|
| 67 |
+
|
| 68 |
+
# append to collection
|
| 69 |
+
output.append(x_i)
|
| 70 |
+
return torch.stack(output).float()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class WanRMSNorm(nn.Module):
|
| 74 |
+
|
| 75 |
+
def __init__(self, dim, eps=1e-5):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.dim = dim
|
| 78 |
+
self.eps = eps
|
| 79 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 80 |
+
|
| 81 |
+
def forward(self, x):
|
| 82 |
+
r"""
|
| 83 |
+
Args:
|
| 84 |
+
x(Tensor): Shape [B, L, C]
|
| 85 |
+
"""
|
| 86 |
+
return self._norm(x.float()).type_as(x) * self.weight
|
| 87 |
+
|
| 88 |
+
def _norm(self, x):
|
| 89 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class WanLayerNorm(nn.LayerNorm):
|
| 93 |
+
|
| 94 |
+
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
|
| 95 |
+
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
|
| 96 |
+
|
| 97 |
+
def forward(self, x):
|
| 98 |
+
r"""
|
| 99 |
+
Args:
|
| 100 |
+
x(Tensor): Shape [B, L, C]
|
| 101 |
+
"""
|
| 102 |
+
return super().forward(x.float()).type_as(x)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class WanSelfAttention(nn.Module):
|
| 106 |
+
|
| 107 |
+
def __init__(self,
|
| 108 |
+
dim,
|
| 109 |
+
num_heads,
|
| 110 |
+
window_size=(-1, -1),
|
| 111 |
+
qk_norm=True,
|
| 112 |
+
eps=1e-6):
|
| 113 |
+
assert dim % num_heads == 0
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.dim = dim
|
| 116 |
+
self.num_heads = num_heads
|
| 117 |
+
self.head_dim = dim // num_heads
|
| 118 |
+
self.window_size = window_size
|
| 119 |
+
self.qk_norm = qk_norm
|
| 120 |
+
self.eps = eps
|
| 121 |
+
|
| 122 |
+
# layers
|
| 123 |
+
self.q = nn.Linear(dim, dim)
|
| 124 |
+
self.k = nn.Linear(dim, dim)
|
| 125 |
+
self.v = nn.Linear(dim, dim)
|
| 126 |
+
self.o = nn.Linear(dim, dim)
|
| 127 |
+
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 128 |
+
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 129 |
+
|
| 130 |
+
def forward(self, x, seq_lens, grid_sizes, freqs):
|
| 131 |
+
r"""
|
| 132 |
+
Args:
|
| 133 |
+
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
| 134 |
+
seq_lens(Tensor): Shape [B]
|
| 135 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
| 136 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 137 |
+
"""
|
| 138 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 139 |
+
|
| 140 |
+
# query, key, value function
|
| 141 |
+
def qkv_fn(x):
|
| 142 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 143 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 144 |
+
v = self.v(x).view(b, s, n, d)
|
| 145 |
+
return q, k, v
|
| 146 |
+
|
| 147 |
+
q, k, v = qkv_fn(x)
|
| 148 |
+
|
| 149 |
+
x = flash_attention(
|
| 150 |
+
q=rope_apply(q, grid_sizes, freqs),
|
| 151 |
+
k=rope_apply(k, grid_sizes, freqs),
|
| 152 |
+
v=v,
|
| 153 |
+
k_lens=seq_lens,
|
| 154 |
+
window_size=self.window_size)
|
| 155 |
+
|
| 156 |
+
# output
|
| 157 |
+
x = x.flatten(2)
|
| 158 |
+
x = self.o(x)
|
| 159 |
+
return x
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class WanT2VCrossAttention(WanSelfAttention):
|
| 163 |
+
|
| 164 |
+
def forward(self, x, context, context_lens):
|
| 165 |
+
r"""
|
| 166 |
+
Args:
|
| 167 |
+
x(Tensor): Shape [B, L1, C]
|
| 168 |
+
context(Tensor): Shape [B, L2, C]
|
| 169 |
+
context_lens(Tensor): Shape [B]
|
| 170 |
+
"""
|
| 171 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
| 172 |
+
|
| 173 |
+
# compute query, key, value
|
| 174 |
+
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
| 175 |
+
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
| 176 |
+
v = self.v(context).view(b, -1, n, d)
|
| 177 |
+
|
| 178 |
+
# compute attention
|
| 179 |
+
x = flash_attention(q, k, v, k_lens=context_lens)
|
| 180 |
+
|
| 181 |
+
# output
|
| 182 |
+
x = x.flatten(2)
|
| 183 |
+
x = self.o(x)
|
| 184 |
+
return x
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class WanI2VCrossAttention(WanSelfAttention):
|
| 188 |
+
|
| 189 |
+
def __init__(self,
|
| 190 |
+
dim,
|
| 191 |
+
num_heads,
|
| 192 |
+
window_size=(-1, -1),
|
| 193 |
+
qk_norm=True,
|
| 194 |
+
eps=1e-6):
|
| 195 |
+
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
| 196 |
+
|
| 197 |
+
self.k_img = nn.Linear(dim, dim)
|
| 198 |
+
self.v_img = nn.Linear(dim, dim)
|
| 199 |
+
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
| 200 |
+
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 201 |
+
|
| 202 |
+
def forward(self, x, context, context_lens):
|
| 203 |
+
r"""
|
| 204 |
+
Args:
|
| 205 |
+
x(Tensor): Shape [B, L1, C]
|
| 206 |
+
context(Tensor): Shape [B, L2, C]
|
| 207 |
+
context_lens(Tensor): Shape [B]
|
| 208 |
+
"""
|
| 209 |
+
image_context_length = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER
|
| 210 |
+
context_img = context[:, :image_context_length]
|
| 211 |
+
context = context[:, image_context_length:]
|
| 212 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
| 213 |
+
|
| 214 |
+
# compute query, key, value
|
| 215 |
+
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
| 216 |
+
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
| 217 |
+
v = self.v(context).view(b, -1, n, d)
|
| 218 |
+
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
|
| 219 |
+
v_img = self.v_img(context_img).view(b, -1, n, d)
|
| 220 |
+
img_x = flash_attention(q, k_img, v_img, k_lens=None)
|
| 221 |
+
# compute attention
|
| 222 |
+
x = flash_attention(q, k, v, k_lens=context_lens)
|
| 223 |
+
|
| 224 |
+
# output
|
| 225 |
+
x = x.flatten(2)
|
| 226 |
+
img_x = img_x.flatten(2)
|
| 227 |
+
x = x + img_x
|
| 228 |
+
x = self.o(x)
|
| 229 |
+
return x
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
WAN_CROSSATTENTION_CLASSES = {
|
| 233 |
+
't2v_cross_attn': WanT2VCrossAttention,
|
| 234 |
+
'i2v_cross_attn': WanI2VCrossAttention,
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class WanAttentionBlock(nn.Module):
|
| 239 |
+
|
| 240 |
+
def __init__(self,
|
| 241 |
+
cross_attn_type,
|
| 242 |
+
dim,
|
| 243 |
+
ffn_dim,
|
| 244 |
+
num_heads,
|
| 245 |
+
window_size=(-1, -1),
|
| 246 |
+
qk_norm=True,
|
| 247 |
+
cross_attn_norm=False,
|
| 248 |
+
eps=1e-6):
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.dim = dim
|
| 251 |
+
self.ffn_dim = ffn_dim
|
| 252 |
+
self.num_heads = num_heads
|
| 253 |
+
self.window_size = window_size
|
| 254 |
+
self.qk_norm = qk_norm
|
| 255 |
+
self.cross_attn_norm = cross_attn_norm
|
| 256 |
+
self.eps = eps
|
| 257 |
+
|
| 258 |
+
# layers
|
| 259 |
+
self.norm1 = WanLayerNorm(dim, eps)
|
| 260 |
+
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
|
| 261 |
+
eps)
|
| 262 |
+
self.norm3 = WanLayerNorm(
|
| 263 |
+
dim, eps,
|
| 264 |
+
elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
| 265 |
+
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
|
| 266 |
+
num_heads,
|
| 267 |
+
(-1, -1),
|
| 268 |
+
qk_norm,
|
| 269 |
+
eps)
|
| 270 |
+
self.norm2 = WanLayerNorm(dim, eps)
|
| 271 |
+
self.ffn = nn.Sequential(
|
| 272 |
+
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
|
| 273 |
+
nn.Linear(ffn_dim, dim))
|
| 274 |
+
|
| 275 |
+
# modulation
|
| 276 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
| 277 |
+
|
| 278 |
+
def forward(
|
| 279 |
+
self,
|
| 280 |
+
x,
|
| 281 |
+
e,
|
| 282 |
+
seq_lens,
|
| 283 |
+
grid_sizes,
|
| 284 |
+
freqs,
|
| 285 |
+
context,
|
| 286 |
+
context_lens,
|
| 287 |
+
):
|
| 288 |
+
r"""
|
| 289 |
+
Args:
|
| 290 |
+
x(Tensor): Shape [B, L, C]
|
| 291 |
+
e(Tensor): Shape [B, 6, C]
|
| 292 |
+
seq_lens(Tensor): Shape [B], length of each sequence in batch
|
| 293 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
| 294 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 295 |
+
"""
|
| 296 |
+
assert e.dtype == torch.float32
|
| 297 |
+
with amp.autocast(dtype=torch.float32):
|
| 298 |
+
e = (self.modulation.to(e.device) + e).chunk(6, dim=1)
|
| 299 |
+
assert e[0].dtype == torch.float32
|
| 300 |
+
|
| 301 |
+
# self-attention
|
| 302 |
+
y = self.self_attn(
|
| 303 |
+
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
|
| 304 |
+
freqs)
|
| 305 |
+
with amp.autocast(dtype=torch.float32):
|
| 306 |
+
x = x + y * e[2]
|
| 307 |
+
|
| 308 |
+
# cross-attention & ffn function
|
| 309 |
+
def cross_attn_ffn(x, context, context_lens, e):
|
| 310 |
+
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
| 311 |
+
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
|
| 312 |
+
with amp.autocast(dtype=torch.float32):
|
| 313 |
+
x = x + y * e[5]
|
| 314 |
+
return x
|
| 315 |
+
|
| 316 |
+
x = cross_attn_ffn(x, context, context_lens, e)
|
| 317 |
+
return x
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class Head(nn.Module):
|
| 321 |
+
|
| 322 |
+
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
| 323 |
+
super().__init__()
|
| 324 |
+
self.dim = dim
|
| 325 |
+
self.out_dim = out_dim
|
| 326 |
+
self.patch_size = patch_size
|
| 327 |
+
self.eps = eps
|
| 328 |
+
|
| 329 |
+
# layers
|
| 330 |
+
out_dim = math.prod(patch_size) * out_dim
|
| 331 |
+
self.norm = WanLayerNorm(dim, eps)
|
| 332 |
+
self.head = nn.Linear(dim, out_dim)
|
| 333 |
+
|
| 334 |
+
# modulation
|
| 335 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
| 336 |
+
|
| 337 |
+
def forward(self, x, e):
|
| 338 |
+
r"""
|
| 339 |
+
Args:
|
| 340 |
+
x(Tensor): Shape [B, L1, C]
|
| 341 |
+
e(Tensor): Shape [B, C]
|
| 342 |
+
"""
|
| 343 |
+
assert e.dtype == torch.float32
|
| 344 |
+
with amp.autocast(dtype=torch.float32):
|
| 345 |
+
e = (self.modulation.to(e.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
| 346 |
+
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
| 347 |
+
return x
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
class MLPProj(torch.nn.Module):
|
| 351 |
+
|
| 352 |
+
def __init__(self, in_dim, out_dim, flf_pos_emb=False):
|
| 353 |
+
super().__init__()
|
| 354 |
+
|
| 355 |
+
self.proj = torch.nn.Sequential(
|
| 356 |
+
torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
|
| 357 |
+
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
|
| 358 |
+
torch.nn.LayerNorm(out_dim))
|
| 359 |
+
if flf_pos_emb: # NOTE: we only use this for `flf2v`
|
| 360 |
+
self.emb_pos = nn.Parameter(
|
| 361 |
+
torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
|
| 362 |
+
|
| 363 |
+
def forward(self, image_embeds):
|
| 364 |
+
if hasattr(self, 'emb_pos'):
|
| 365 |
+
bs, n, d = image_embeds.shape
|
| 366 |
+
image_embeds = image_embeds.view(-1, 2 * n, d)
|
| 367 |
+
image_embeds = image_embeds + self.emb_pos
|
| 368 |
+
clip_extra_context_tokens = self.proj(image_embeds)
|
| 369 |
+
return clip_extra_context_tokens
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class WanModel(ModelMixin, ConfigMixin):
|
| 373 |
+
r"""
|
| 374 |
+
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
| 375 |
+
"""
|
| 376 |
+
|
| 377 |
+
ignore_for_config = [
|
| 378 |
+
'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
|
| 379 |
+
]
|
| 380 |
+
_no_split_modules = ['WanAttentionBlock']
|
| 381 |
+
|
| 382 |
+
@register_to_config
|
| 383 |
+
def __init__(self,
|
| 384 |
+
model_type='t2v',
|
| 385 |
+
patch_size=(1, 2, 2),
|
| 386 |
+
text_len=512,
|
| 387 |
+
in_dim=16,
|
| 388 |
+
dim=2048,
|
| 389 |
+
ffn_dim=8192,
|
| 390 |
+
freq_dim=256,
|
| 391 |
+
text_dim=4096,
|
| 392 |
+
out_dim=16,
|
| 393 |
+
num_heads=16,
|
| 394 |
+
num_layers=32,
|
| 395 |
+
window_size=(-1, -1),
|
| 396 |
+
qk_norm=True,
|
| 397 |
+
cross_attn_norm=True,
|
| 398 |
+
eps=1e-6):
|
| 399 |
+
r"""
|
| 400 |
+
Initialize the diffusion model backbone.
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
model_type (`str`, *optional*, defaults to 't2v'):
|
| 404 |
+
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
|
| 405 |
+
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
| 406 |
+
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
| 407 |
+
text_len (`int`, *optional*, defaults to 512):
|
| 408 |
+
Fixed length for text embeddings
|
| 409 |
+
in_dim (`int`, *optional*, defaults to 16):
|
| 410 |
+
Input video channels (C_in)
|
| 411 |
+
dim (`int`, *optional*, defaults to 2048):
|
| 412 |
+
Hidden dimension of the transformer
|
| 413 |
+
ffn_dim (`int`, *optional*, defaults to 8192):
|
| 414 |
+
Intermediate dimension in feed-forward network
|
| 415 |
+
freq_dim (`int`, *optional*, defaults to 256):
|
| 416 |
+
Dimension for sinusoidal time embeddings
|
| 417 |
+
text_dim (`int`, *optional*, defaults to 4096):
|
| 418 |
+
Input dimension for text embeddings
|
| 419 |
+
out_dim (`int`, *optional*, defaults to 16):
|
| 420 |
+
Output video channels (C_out)
|
| 421 |
+
num_heads (`int`, *optional*, defaults to 16):
|
| 422 |
+
Number of attention heads
|
| 423 |
+
num_layers (`int`, *optional*, defaults to 32):
|
| 424 |
+
Number of transformer blocks
|
| 425 |
+
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
| 426 |
+
Window size for local attention (-1 indicates global attention)
|
| 427 |
+
qk_norm (`bool`, *optional*, defaults to True):
|
| 428 |
+
Enable query/key normalization
|
| 429 |
+
cross_attn_norm (`bool`, *optional*, defaults to False):
|
| 430 |
+
Enable cross-attention normalization
|
| 431 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
| 432 |
+
Epsilon value for normalization layers
|
| 433 |
+
"""
|
| 434 |
+
|
| 435 |
+
super().__init__()
|
| 436 |
+
|
| 437 |
+
assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
|
| 438 |
+
self.model_type = model_type
|
| 439 |
+
|
| 440 |
+
self.patch_size = patch_size
|
| 441 |
+
self.text_len = text_len
|
| 442 |
+
self.in_dim = in_dim
|
| 443 |
+
self.dim = dim
|
| 444 |
+
self.ffn_dim = ffn_dim
|
| 445 |
+
self.freq_dim = freq_dim
|
| 446 |
+
self.text_dim = text_dim
|
| 447 |
+
self.out_dim = out_dim
|
| 448 |
+
self.num_heads = num_heads
|
| 449 |
+
self.num_layers = num_layers
|
| 450 |
+
self.window_size = window_size
|
| 451 |
+
self.qk_norm = qk_norm
|
| 452 |
+
self.cross_attn_norm = cross_attn_norm
|
| 453 |
+
self.eps = eps
|
| 454 |
+
|
| 455 |
+
# embeddings
|
| 456 |
+
self.patch_embedding = nn.Conv3d(
|
| 457 |
+
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
| 458 |
+
self.text_embedding = nn.Sequential(
|
| 459 |
+
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
|
| 460 |
+
nn.Linear(dim, dim))
|
| 461 |
+
|
| 462 |
+
self.time_embedding = nn.Sequential(
|
| 463 |
+
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
| 464 |
+
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
| 465 |
+
|
| 466 |
+
# blocks
|
| 467 |
+
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
| 468 |
+
self.blocks = nn.ModuleList([
|
| 469 |
+
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
| 470 |
+
window_size, qk_norm, cross_attn_norm, eps)
|
| 471 |
+
for _ in range(num_layers)
|
| 472 |
+
])
|
| 473 |
+
|
| 474 |
+
# head
|
| 475 |
+
self.head = Head(dim, out_dim, patch_size, eps)
|
| 476 |
+
|
| 477 |
+
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
| 478 |
+
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
| 479 |
+
d = dim // num_heads
|
| 480 |
+
self.freqs = torch.cat([
|
| 481 |
+
rope_params(1024, d - 4 * (d // 6)),
|
| 482 |
+
rope_params(1024, 2 * (d // 6)),
|
| 483 |
+
rope_params(1024, 2 * (d // 6))
|
| 484 |
+
],
|
| 485 |
+
dim=1)
|
| 486 |
+
|
| 487 |
+
if model_type == 'i2v' or model_type == 'flf2v':
|
| 488 |
+
self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
|
| 489 |
+
|
| 490 |
+
# initialize weights
|
| 491 |
+
self.init_weights()
|
| 492 |
+
|
| 493 |
+
def forward(
|
| 494 |
+
self,
|
| 495 |
+
x,
|
| 496 |
+
t,
|
| 497 |
+
context,
|
| 498 |
+
seq_len,
|
| 499 |
+
clip_fea=None,
|
| 500 |
+
y=None,
|
| 501 |
+
):
|
| 502 |
+
r"""
|
| 503 |
+
Forward pass through the diffusion model
|
| 504 |
+
|
| 505 |
+
Args:
|
| 506 |
+
x (List[Tensor]):
|
| 507 |
+
List of input video tensors, each with shape [C_in, F, H, W]
|
| 508 |
+
t (Tensor):
|
| 509 |
+
Diffusion timesteps tensor of shape [B]
|
| 510 |
+
context (List[Tensor]):
|
| 511 |
+
List of text embeddings each with shape [L, C]
|
| 512 |
+
seq_len (`int`):
|
| 513 |
+
Maximum sequence length for positional encoding
|
| 514 |
+
clip_fea (Tensor, *optional*):
|
| 515 |
+
CLIP image features for image-to-video mode or first-last-frame-to-video mode
|
| 516 |
+
y (List[Tensor], *optional*):
|
| 517 |
+
Conditional video inputs for image-to-video mode, same shape as x
|
| 518 |
+
|
| 519 |
+
Returns:
|
| 520 |
+
List[Tensor]:
|
| 521 |
+
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
| 522 |
+
"""
|
| 523 |
+
if self.model_type == 'i2v' or self.model_type == 'flf2v':
|
| 524 |
+
assert clip_fea is not None and y is not None
|
| 525 |
+
# params
|
| 526 |
+
device = self.patch_embedding.weight.device
|
| 527 |
+
if self.freqs.device != device:
|
| 528 |
+
self.freqs = self.freqs.to(device)
|
| 529 |
+
|
| 530 |
+
if y is not None:
|
| 531 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 532 |
+
|
| 533 |
+
# embeddings
|
| 534 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 535 |
+
grid_sizes = torch.stack(
|
| 536 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 537 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 538 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 539 |
+
assert seq_lens.max() <= seq_len
|
| 540 |
+
x = torch.cat([
|
| 541 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 542 |
+
dim=1) for u in x
|
| 543 |
+
])
|
| 544 |
+
|
| 545 |
+
# time embeddings
|
| 546 |
+
with amp.autocast(dtype=torch.float32):
|
| 547 |
+
e = self.time_embedding(
|
| 548 |
+
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 549 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
| 550 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 551 |
+
|
| 552 |
+
# context
|
| 553 |
+
context_lens = None
|
| 554 |
+
context = self.text_embedding(
|
| 555 |
+
torch.stack([
|
| 556 |
+
torch.cat(
|
| 557 |
+
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 558 |
+
for u in context
|
| 559 |
+
]))
|
| 560 |
+
|
| 561 |
+
if clip_fea is not None:
|
| 562 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 (x2) x dim
|
| 563 |
+
context = torch.concat([context_clip, context], dim=1)
|
| 564 |
+
|
| 565 |
+
# arguments
|
| 566 |
+
kwargs = dict(
|
| 567 |
+
e=e0,
|
| 568 |
+
seq_lens=seq_lens,
|
| 569 |
+
grid_sizes=grid_sizes,
|
| 570 |
+
freqs=self.freqs,
|
| 571 |
+
context=context,
|
| 572 |
+
context_lens=context_lens)
|
| 573 |
+
|
| 574 |
+
for block in self.blocks:
|
| 575 |
+
x = block(x, **kwargs)
|
| 576 |
+
|
| 577 |
+
# head
|
| 578 |
+
x = self.head(x, e)
|
| 579 |
+
|
| 580 |
+
# unpatchify
|
| 581 |
+
x = self.unpatchify(x, grid_sizes)
|
| 582 |
+
return [u.float() for u in x]
|
| 583 |
+
|
| 584 |
+
def unpatchify(self, x, grid_sizes):
|
| 585 |
+
r"""
|
| 586 |
+
Reconstruct video tensors from patch embeddings.
|
| 587 |
+
|
| 588 |
+
Args:
|
| 589 |
+
x (List[Tensor]):
|
| 590 |
+
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
| 591 |
+
grid_sizes (Tensor):
|
| 592 |
+
Original spatial-temporal grid dimensions before patching,
|
| 593 |
+
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
| 594 |
+
|
| 595 |
+
Returns:
|
| 596 |
+
List[Tensor]:
|
| 597 |
+
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
|
| 598 |
+
"""
|
| 599 |
+
|
| 600 |
+
c = self.out_dim
|
| 601 |
+
out = []
|
| 602 |
+
for u, v in zip(x, grid_sizes.tolist()):
|
| 603 |
+
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
|
| 604 |
+
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
| 605 |
+
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
|
| 606 |
+
out.append(u)
|
| 607 |
+
return out
|
| 608 |
+
|
| 609 |
+
def init_weights(self):
|
| 610 |
+
r"""
|
| 611 |
+
Initialize model parameters using Xavier initialization.
|
| 612 |
+
"""
|
| 613 |
+
|
| 614 |
+
# basic init
|
| 615 |
+
for m in self.modules():
|
| 616 |
+
if isinstance(m, nn.Linear):
|
| 617 |
+
nn.init.xavier_uniform_(m.weight)
|
| 618 |
+
if m.bias is not None:
|
| 619 |
+
nn.init.zeros_(m.bias)
|
| 620 |
+
|
| 621 |
+
# init embeddings
|
| 622 |
+
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
|
| 623 |
+
for m in self.text_embedding.modules():
|
| 624 |
+
if isinstance(m, nn.Linear):
|
| 625 |
+
nn.init.normal_(m.weight, std=.02)
|
| 626 |
+
for m in self.time_embedding.modules():
|
| 627 |
+
if isinstance(m, nn.Linear):
|
| 628 |
+
nn.init.normal_(m.weight, std=.02)
|
| 629 |
+
|
| 630 |
+
# init output layer
|
| 631 |
+
nn.init.zeros_(self.head.head.weight)
|
wan/modules/multitalk_model.py
ADDED
|
@@ -0,0 +1,824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
import torch.cuda.amp as amp
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from diffusers import ModelMixin
|
| 12 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 13 |
+
|
| 14 |
+
from .attention import flash_attention, SingleStreamMutiAttention
|
| 15 |
+
from ..utils.multitalk_utils import get_attn_map_with_target
|
| 16 |
+
import logging
|
| 17 |
+
try:
|
| 18 |
+
from sageattention import sageattn
|
| 19 |
+
USE_SAGEATTN = True
|
| 20 |
+
logging.info("Using sageattn")
|
| 21 |
+
except:
|
| 22 |
+
USE_SAGEATTN = False
|
| 23 |
+
|
| 24 |
+
__all__ = ['WanModel']
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def sinusoidal_embedding_1d(dim, position):
|
| 29 |
+
# preprocess
|
| 30 |
+
assert dim % 2 == 0
|
| 31 |
+
half = dim // 2
|
| 32 |
+
position = position.type(torch.float64)
|
| 33 |
+
|
| 34 |
+
# calculation
|
| 35 |
+
sinusoid = torch.outer(
|
| 36 |
+
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
| 37 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
| 38 |
+
return x
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@amp.autocast(enabled=False)
|
| 42 |
+
def rope_params(max_seq_len, dim, theta=10000):
|
| 43 |
+
|
| 44 |
+
assert dim % 2 == 0
|
| 45 |
+
freqs = torch.outer(
|
| 46 |
+
torch.arange(max_seq_len),
|
| 47 |
+
1.0 / torch.pow(theta,
|
| 48 |
+
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
| 49 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
| 50 |
+
return freqs
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@amp.autocast(enabled=False)
|
| 54 |
+
def rope_apply(x, grid_sizes, freqs):
|
| 55 |
+
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
| 56 |
+
|
| 57 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
| 58 |
+
|
| 59 |
+
output = []
|
| 60 |
+
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
| 61 |
+
seq_len = f * h * w
|
| 62 |
+
|
| 63 |
+
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
|
| 64 |
+
s, n, -1, 2))
|
| 65 |
+
freqs_i = torch.cat([
|
| 66 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 67 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 68 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 69 |
+
],
|
| 70 |
+
dim=-1).reshape(seq_len, 1, -1)
|
| 71 |
+
freqs_i = freqs_i.to(device=x_i.device)
|
| 72 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
| 73 |
+
x_i = torch.cat([x_i, x[i, seq_len:]])
|
| 74 |
+
|
| 75 |
+
output.append(x_i)
|
| 76 |
+
return torch.stack(output).float()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class WanRMSNorm(nn.Module):
|
| 80 |
+
|
| 81 |
+
def __init__(self, dim, eps=1e-5):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.dim = dim
|
| 84 |
+
self.eps = eps
|
| 85 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 86 |
+
|
| 87 |
+
def forward(self, x):
|
| 88 |
+
r"""
|
| 89 |
+
Args:
|
| 90 |
+
x(Tensor): Shape [B, L, C]
|
| 91 |
+
"""
|
| 92 |
+
return self._norm(x.float()).type_as(x) * self.weight
|
| 93 |
+
|
| 94 |
+
def _norm(self, x):
|
| 95 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class WanLayerNorm(nn.LayerNorm):
|
| 99 |
+
|
| 100 |
+
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
|
| 101 |
+
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
|
| 102 |
+
|
| 103 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
| 104 |
+
origin_dtype = inputs.dtype
|
| 105 |
+
out = F.layer_norm(
|
| 106 |
+
inputs.float(),
|
| 107 |
+
self.normalized_shape,
|
| 108 |
+
None if self.weight is None else self.weight.float(),
|
| 109 |
+
None if self.bias is None else self.bias.float() ,
|
| 110 |
+
self.eps
|
| 111 |
+
).to(origin_dtype)
|
| 112 |
+
return out
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class WanSelfAttention(nn.Module):
|
| 116 |
+
|
| 117 |
+
def __init__(self,
|
| 118 |
+
dim,
|
| 119 |
+
num_heads,
|
| 120 |
+
window_size=(-1, -1),
|
| 121 |
+
qk_norm=True,
|
| 122 |
+
eps=1e-6):
|
| 123 |
+
assert dim % num_heads == 0
|
| 124 |
+
super().__init__()
|
| 125 |
+
self.dim = dim
|
| 126 |
+
self.num_heads = num_heads
|
| 127 |
+
self.head_dim = dim // num_heads
|
| 128 |
+
self.window_size = window_size
|
| 129 |
+
self.qk_norm = qk_norm
|
| 130 |
+
self.eps = eps
|
| 131 |
+
|
| 132 |
+
# layers
|
| 133 |
+
self.q = nn.Linear(dim, dim)
|
| 134 |
+
self.k = nn.Linear(dim, dim)
|
| 135 |
+
self.v = nn.Linear(dim, dim)
|
| 136 |
+
self.o = nn.Linear(dim, dim)
|
| 137 |
+
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 138 |
+
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 139 |
+
|
| 140 |
+
def forward(self, x, seq_lens, grid_sizes, freqs, ref_target_masks=None):
|
| 141 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 142 |
+
|
| 143 |
+
# query, key, value function
|
| 144 |
+
def qkv_fn(x):
|
| 145 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 146 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 147 |
+
v = self.v(x).view(b, s, n, d)
|
| 148 |
+
return q, k, v
|
| 149 |
+
q, k, v = qkv_fn(x)
|
| 150 |
+
|
| 151 |
+
q = rope_apply(q, grid_sizes, freqs)
|
| 152 |
+
k = rope_apply(k, grid_sizes, freqs)
|
| 153 |
+
|
| 154 |
+
if USE_SAGEATTN:
|
| 155 |
+
x = sageattn(q.to(torch.bfloat16), k.to(torch.bfloat16), v, tensor_layout='NHD')
|
| 156 |
+
else:
|
| 157 |
+
x = flash_attention(
|
| 158 |
+
q=q,
|
| 159 |
+
k=k,
|
| 160 |
+
v=v,
|
| 161 |
+
k_lens=seq_lens,
|
| 162 |
+
window_size=self.window_size
|
| 163 |
+
).type_as(x)
|
| 164 |
+
|
| 165 |
+
# output
|
| 166 |
+
x = x.flatten(2)
|
| 167 |
+
x = self.o(x)
|
| 168 |
+
with torch.no_grad():
|
| 169 |
+
x_ref_attn_map = get_attn_map_with_target(q.type_as(x), k.type_as(x), grid_sizes[0],
|
| 170 |
+
ref_target_masks=ref_target_masks)
|
| 171 |
+
|
| 172 |
+
return x, x_ref_attn_map
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class WanI2VCrossAttention(WanSelfAttention):
|
| 176 |
+
|
| 177 |
+
def __init__(self,
|
| 178 |
+
dim,
|
| 179 |
+
num_heads,
|
| 180 |
+
window_size=(-1, -1),
|
| 181 |
+
qk_norm=True,
|
| 182 |
+
eps=1e-6):
|
| 183 |
+
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
| 184 |
+
|
| 185 |
+
self.k_img = nn.Linear(dim, dim)
|
| 186 |
+
self.v_img = nn.Linear(dim, dim)
|
| 187 |
+
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 188 |
+
|
| 189 |
+
def forward(self, x, context, context_lens):
|
| 190 |
+
context_img = context[:, :257]
|
| 191 |
+
context = context[:, 257:]
|
| 192 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
| 193 |
+
|
| 194 |
+
# compute query, key, value
|
| 195 |
+
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
| 196 |
+
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
| 197 |
+
v = self.v(context).view(b, -1, n, d)
|
| 198 |
+
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
|
| 199 |
+
v_img = self.v_img(context_img).view(b, -1, n, d)
|
| 200 |
+
if USE_SAGEATTN:
|
| 201 |
+
img_x = sageattn(q, k_img, v_img, tensor_layout='NHD')
|
| 202 |
+
x = sageattn(q, k, v, tensor_layout='NHD')
|
| 203 |
+
else:
|
| 204 |
+
img_x = flash_attention(q, k_img, v_img, k_lens=None)
|
| 205 |
+
# compute attention
|
| 206 |
+
x = flash_attention(q, k, v, k_lens=context_lens)
|
| 207 |
+
|
| 208 |
+
# output
|
| 209 |
+
x = x.flatten(2)
|
| 210 |
+
img_x = img_x.flatten(2)
|
| 211 |
+
x = x + img_x
|
| 212 |
+
x = self.o(x)
|
| 213 |
+
return x
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class WanAttentionBlock(nn.Module):
|
| 217 |
+
|
| 218 |
+
def __init__(self,
|
| 219 |
+
cross_attn_type,
|
| 220 |
+
dim,
|
| 221 |
+
ffn_dim,
|
| 222 |
+
num_heads,
|
| 223 |
+
window_size=(-1, -1),
|
| 224 |
+
qk_norm=True,
|
| 225 |
+
cross_attn_norm=False,
|
| 226 |
+
eps=1e-6,
|
| 227 |
+
output_dim=768,
|
| 228 |
+
norm_input_visual=True,
|
| 229 |
+
class_range=24,
|
| 230 |
+
class_interval=4):
|
| 231 |
+
super().__init__()
|
| 232 |
+
self.dim = dim
|
| 233 |
+
self.ffn_dim = ffn_dim
|
| 234 |
+
self.num_heads = num_heads
|
| 235 |
+
self.window_size = window_size
|
| 236 |
+
self.qk_norm = qk_norm
|
| 237 |
+
self.cross_attn_norm = cross_attn_norm
|
| 238 |
+
self.eps = eps
|
| 239 |
+
|
| 240 |
+
# layers
|
| 241 |
+
self.norm1 = WanLayerNorm(dim, eps)
|
| 242 |
+
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
| 243 |
+
self.norm3 = WanLayerNorm(
|
| 244 |
+
dim, eps,
|
| 245 |
+
elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
| 246 |
+
self.cross_attn = WanI2VCrossAttention(dim,
|
| 247 |
+
num_heads,
|
| 248 |
+
(-1, -1),
|
| 249 |
+
qk_norm,
|
| 250 |
+
eps)
|
| 251 |
+
self.norm2 = WanLayerNorm(dim, eps)
|
| 252 |
+
self.ffn = nn.Sequential(
|
| 253 |
+
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
|
| 254 |
+
nn.Linear(ffn_dim, dim))
|
| 255 |
+
|
| 256 |
+
# modulation
|
| 257 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
| 258 |
+
|
| 259 |
+
# init audio module
|
| 260 |
+
self.audio_cross_attn = SingleStreamMutiAttention(
|
| 261 |
+
dim=dim,
|
| 262 |
+
encoder_hidden_states_dim=output_dim,
|
| 263 |
+
num_heads=num_heads,
|
| 264 |
+
qk_norm=False,
|
| 265 |
+
qkv_bias=True,
|
| 266 |
+
eps=eps,
|
| 267 |
+
norm_layer=WanRMSNorm,
|
| 268 |
+
class_range=class_range,
|
| 269 |
+
class_interval=class_interval
|
| 270 |
+
)
|
| 271 |
+
self.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True) if norm_input_visual else nn.Identity()
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def forward(
|
| 275 |
+
self,
|
| 276 |
+
x,
|
| 277 |
+
e,
|
| 278 |
+
seq_lens,
|
| 279 |
+
grid_sizes,
|
| 280 |
+
freqs,
|
| 281 |
+
context,
|
| 282 |
+
context_lens,
|
| 283 |
+
audio_embedding=None,
|
| 284 |
+
ref_target_masks=None,
|
| 285 |
+
human_num=None,
|
| 286 |
+
):
|
| 287 |
+
|
| 288 |
+
dtype = x.dtype
|
| 289 |
+
assert e.dtype == torch.float32
|
| 290 |
+
with amp.autocast(dtype=torch.float32):
|
| 291 |
+
e = (self.modulation.to(e.device) + e).chunk(6, dim=1)
|
| 292 |
+
assert e[0].dtype == torch.float32
|
| 293 |
+
|
| 294 |
+
# self-attention
|
| 295 |
+
y, x_ref_attn_map = self.self_attn(
|
| 296 |
+
(self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, grid_sizes,
|
| 297 |
+
freqs, ref_target_masks=ref_target_masks)
|
| 298 |
+
with amp.autocast(dtype=torch.float32):
|
| 299 |
+
x = x + y * e[2]
|
| 300 |
+
|
| 301 |
+
x = x.to(dtype)
|
| 302 |
+
|
| 303 |
+
# cross-attention of text
|
| 304 |
+
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
| 305 |
+
|
| 306 |
+
# cross attn of audio
|
| 307 |
+
x_a = self.audio_cross_attn(self.norm_x(x), encoder_hidden_states=audio_embedding,
|
| 308 |
+
shape=grid_sizes[0], x_ref_attn_map=x_ref_attn_map, human_num=human_num)
|
| 309 |
+
x = x + x_a
|
| 310 |
+
|
| 311 |
+
y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).to(dtype))
|
| 312 |
+
with amp.autocast(dtype=torch.float32):
|
| 313 |
+
x = x + y * e[5]
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
x = x.to(dtype)
|
| 317 |
+
|
| 318 |
+
return x
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class Head(nn.Module):
|
| 322 |
+
|
| 323 |
+
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
| 324 |
+
super().__init__()
|
| 325 |
+
self.dim = dim
|
| 326 |
+
self.out_dim = out_dim
|
| 327 |
+
self.patch_size = patch_size
|
| 328 |
+
self.eps = eps
|
| 329 |
+
|
| 330 |
+
# layers
|
| 331 |
+
out_dim = math.prod(patch_size) * out_dim
|
| 332 |
+
self.norm = WanLayerNorm(dim, eps)
|
| 333 |
+
self.head = nn.Linear(dim, out_dim)
|
| 334 |
+
|
| 335 |
+
# modulation
|
| 336 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
| 337 |
+
|
| 338 |
+
def forward(self, x, e):
|
| 339 |
+
r"""
|
| 340 |
+
Args:
|
| 341 |
+
x(Tensor): Shape [B, L1, C]
|
| 342 |
+
e(Tensor): Shape [B, C]
|
| 343 |
+
"""
|
| 344 |
+
assert e.dtype == torch.float32
|
| 345 |
+
with amp.autocast(dtype=torch.float32):
|
| 346 |
+
e = (self.modulation.to(e.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
| 347 |
+
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
| 348 |
+
return x
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class MLPProj(torch.nn.Module):
|
| 352 |
+
|
| 353 |
+
def __init__(self, in_dim, out_dim):
|
| 354 |
+
super().__init__()
|
| 355 |
+
|
| 356 |
+
self.proj = torch.nn.Sequential(
|
| 357 |
+
torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
|
| 358 |
+
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
|
| 359 |
+
torch.nn.LayerNorm(out_dim))
|
| 360 |
+
|
| 361 |
+
def forward(self, image_embeds):
|
| 362 |
+
clip_extra_context_tokens = self.proj(image_embeds)
|
| 363 |
+
return clip_extra_context_tokens
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class AudioProjModel(ModelMixin, ConfigMixin):
|
| 367 |
+
def __init__(
|
| 368 |
+
self,
|
| 369 |
+
seq_len=5,
|
| 370 |
+
seq_len_vf=12,
|
| 371 |
+
blocks=12,
|
| 372 |
+
channels=768,
|
| 373 |
+
intermediate_dim=512,
|
| 374 |
+
output_dim=768,
|
| 375 |
+
context_tokens=32,
|
| 376 |
+
norm_output_audio=False,
|
| 377 |
+
):
|
| 378 |
+
super().__init__()
|
| 379 |
+
|
| 380 |
+
self.seq_len = seq_len
|
| 381 |
+
self.blocks = blocks
|
| 382 |
+
self.channels = channels
|
| 383 |
+
self.input_dim = seq_len * blocks * channels
|
| 384 |
+
self.input_dim_vf = seq_len_vf * blocks * channels
|
| 385 |
+
self.intermediate_dim = intermediate_dim
|
| 386 |
+
self.context_tokens = context_tokens
|
| 387 |
+
self.output_dim = output_dim
|
| 388 |
+
|
| 389 |
+
# define multiple linear layers
|
| 390 |
+
self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
|
| 391 |
+
self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
|
| 392 |
+
self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
|
| 393 |
+
self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
|
| 394 |
+
self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()
|
| 395 |
+
|
| 396 |
+
def forward(self, audio_embeds, audio_embeds_vf):
|
| 397 |
+
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
|
| 398 |
+
B, _, _, S, C = audio_embeds.shape
|
| 399 |
+
|
| 400 |
+
# process audio of first frame
|
| 401 |
+
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
| 402 |
+
batch_size, window_size, blocks, channels = audio_embeds.shape
|
| 403 |
+
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
| 404 |
+
|
| 405 |
+
# process audio of latter frame
|
| 406 |
+
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
|
| 407 |
+
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
|
| 408 |
+
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
|
| 409 |
+
|
| 410 |
+
# first projection
|
| 411 |
+
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
| 412 |
+
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
|
| 413 |
+
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
|
| 414 |
+
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
|
| 415 |
+
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
|
| 416 |
+
batch_size_c, N_t, C_a = audio_embeds_c.shape
|
| 417 |
+
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
|
| 418 |
+
|
| 419 |
+
# second projection
|
| 420 |
+
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
|
| 421 |
+
|
| 422 |
+
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.output_dim)
|
| 423 |
+
|
| 424 |
+
# normalization and reshape
|
| 425 |
+
with amp.autocast(dtype=torch.float32):
|
| 426 |
+
context_tokens = self.norm(context_tokens)
|
| 427 |
+
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
|
| 428 |
+
|
| 429 |
+
return context_tokens
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
class WanModel(ModelMixin, ConfigMixin):
|
| 433 |
+
r"""
|
| 434 |
+
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
| 435 |
+
"""
|
| 436 |
+
|
| 437 |
+
ignore_for_config = [
|
| 438 |
+
'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
|
| 439 |
+
]
|
| 440 |
+
_no_split_modules = ['WanAttentionBlock']
|
| 441 |
+
|
| 442 |
+
@register_to_config
|
| 443 |
+
def __init__(self,
|
| 444 |
+
model_type='i2v',
|
| 445 |
+
patch_size=(1, 2, 2),
|
| 446 |
+
text_len=512,
|
| 447 |
+
in_dim=16,
|
| 448 |
+
dim=2048,
|
| 449 |
+
ffn_dim=8192,
|
| 450 |
+
freq_dim=256,
|
| 451 |
+
text_dim=4096,
|
| 452 |
+
out_dim=16,
|
| 453 |
+
num_heads=16,
|
| 454 |
+
num_layers=32,
|
| 455 |
+
window_size=(-1, -1),
|
| 456 |
+
qk_norm=True,
|
| 457 |
+
cross_attn_norm=True,
|
| 458 |
+
eps=1e-6,
|
| 459 |
+
# audio params
|
| 460 |
+
audio_window=5,
|
| 461 |
+
intermediate_dim=512,
|
| 462 |
+
output_dim=768,
|
| 463 |
+
context_tokens=32,
|
| 464 |
+
vae_scale=4, # vae timedownsample scale
|
| 465 |
+
|
| 466 |
+
norm_input_visual=True,
|
| 467 |
+
norm_output_audio=True,
|
| 468 |
+
weight_init=True):
|
| 469 |
+
super().__init__()
|
| 470 |
+
|
| 471 |
+
assert model_type == 'i2v', 'MultiTalk model requires your model_type is i2v.'
|
| 472 |
+
self.model_type = model_type
|
| 473 |
+
|
| 474 |
+
self.patch_size = patch_size
|
| 475 |
+
self.text_len = text_len
|
| 476 |
+
self.in_dim = in_dim
|
| 477 |
+
self.dim = dim
|
| 478 |
+
self.ffn_dim = ffn_dim
|
| 479 |
+
self.freq_dim = freq_dim
|
| 480 |
+
self.text_dim = text_dim
|
| 481 |
+
self.out_dim = out_dim
|
| 482 |
+
self.num_heads = num_heads
|
| 483 |
+
self.num_layers = num_layers
|
| 484 |
+
self.window_size = window_size
|
| 485 |
+
self.qk_norm = qk_norm
|
| 486 |
+
self.cross_attn_norm = cross_attn_norm
|
| 487 |
+
self.eps = eps
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
self.norm_output_audio = norm_output_audio
|
| 491 |
+
self.audio_window = audio_window
|
| 492 |
+
self.intermediate_dim = intermediate_dim
|
| 493 |
+
self.vae_scale = vae_scale
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# embeddings
|
| 497 |
+
self.patch_embedding = nn.Conv3d(
|
| 498 |
+
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
| 499 |
+
self.text_embedding = nn.Sequential(
|
| 500 |
+
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
|
| 501 |
+
nn.Linear(dim, dim))
|
| 502 |
+
|
| 503 |
+
self.time_embedding = nn.Sequential(
|
| 504 |
+
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
| 505 |
+
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
| 506 |
+
|
| 507 |
+
# blocks
|
| 508 |
+
cross_attn_type = 'i2v_cross_attn'
|
| 509 |
+
self.blocks = nn.ModuleList([
|
| 510 |
+
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
| 511 |
+
window_size, qk_norm, cross_attn_norm, eps,
|
| 512 |
+
output_dim=output_dim, norm_input_visual=norm_input_visual)
|
| 513 |
+
for _ in range(num_layers)
|
| 514 |
+
])
|
| 515 |
+
|
| 516 |
+
# head
|
| 517 |
+
self.head = Head(dim, out_dim, patch_size, eps)
|
| 518 |
+
|
| 519 |
+
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
| 520 |
+
d = dim // num_heads
|
| 521 |
+
self.freqs = torch.cat([
|
| 522 |
+
rope_params(1024, d - 4 * (d // 6)),
|
| 523 |
+
rope_params(1024, 2 * (d // 6)),
|
| 524 |
+
rope_params(1024, 2 * (d // 6))
|
| 525 |
+
],
|
| 526 |
+
dim=1)
|
| 527 |
+
|
| 528 |
+
if model_type == 'i2v':
|
| 529 |
+
self.img_emb = MLPProj(1280, dim)
|
| 530 |
+
else:
|
| 531 |
+
raise NotImplementedError('Not supported model type.')
|
| 532 |
+
|
| 533 |
+
# init audio adapter
|
| 534 |
+
self.audio_proj = AudioProjModel(
|
| 535 |
+
seq_len=audio_window,
|
| 536 |
+
seq_len_vf=audio_window+vae_scale-1,
|
| 537 |
+
intermediate_dim=intermediate_dim,
|
| 538 |
+
output_dim=output_dim,
|
| 539 |
+
context_tokens=context_tokens,
|
| 540 |
+
norm_output_audio=norm_output_audio,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
# initialize weights
|
| 545 |
+
if weight_init:
|
| 546 |
+
self.init_weights()
|
| 547 |
+
|
| 548 |
+
def init_freqs(self):
|
| 549 |
+
d = self.dim // self.num_heads
|
| 550 |
+
self.freqs = torch.cat([
|
| 551 |
+
rope_params(1024, d - 4 * (d // 6)),
|
| 552 |
+
rope_params(1024, 2 * (d // 6)),
|
| 553 |
+
rope_params(1024, 2 * (d // 6))
|
| 554 |
+
],
|
| 555 |
+
dim=1)
|
| 556 |
+
|
| 557 |
+
def teacache_init(
|
| 558 |
+
self,
|
| 559 |
+
use_ret_steps=True,
|
| 560 |
+
teacache_thresh=0.2,
|
| 561 |
+
sample_steps=40,
|
| 562 |
+
model_scale='infinitetalk-480',
|
| 563 |
+
):
|
| 564 |
+
print("teacache_init")
|
| 565 |
+
self.enable_teacache = True
|
| 566 |
+
|
| 567 |
+
self.__class__.cnt = 0
|
| 568 |
+
self.__class__.num_steps = sample_steps*3
|
| 569 |
+
self.__class__.teacache_thresh = teacache_thresh
|
| 570 |
+
self.__class__.accumulated_rel_l1_distance_even = 0
|
| 571 |
+
self.__class__.accumulated_rel_l1_distance_odd = 0
|
| 572 |
+
self.__class__.previous_e0_even = None
|
| 573 |
+
self.__class__.previous_e0_odd = None
|
| 574 |
+
self.__class__.previous_residual_even = None
|
| 575 |
+
self.__class__.previous_residual_odd = None
|
| 576 |
+
self.__class__.use_ret_steps = use_ret_steps
|
| 577 |
+
|
| 578 |
+
if use_ret_steps:
|
| 579 |
+
if model_scale == 'infinitetalk-480':
|
| 580 |
+
self.__class__.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
|
| 581 |
+
if model_scale == 'infinitetalk-720':
|
| 582 |
+
self.__class__.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
|
| 583 |
+
self.__class__.ret_steps = 5*3
|
| 584 |
+
self.__class__.cutoff_steps = sample_steps*3
|
| 585 |
+
else:
|
| 586 |
+
if model_scale == 'infinitetalk-480':
|
| 587 |
+
self.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
|
| 588 |
+
|
| 589 |
+
if model_scale == 'infinitetalk-720':
|
| 590 |
+
self.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
|
| 591 |
+
self.__class__.ret_steps = 1*3
|
| 592 |
+
self.__class__.cutoff_steps = sample_steps*3 - 3
|
| 593 |
+
print("teacache_init done")
|
| 594 |
+
|
| 595 |
+
def disable_teacache(self):
|
| 596 |
+
self.enable_teacache = False
|
| 597 |
+
|
| 598 |
+
def forward(
|
| 599 |
+
self,
|
| 600 |
+
x,
|
| 601 |
+
t,
|
| 602 |
+
context,
|
| 603 |
+
seq_len,
|
| 604 |
+
clip_fea=None,
|
| 605 |
+
y=None,
|
| 606 |
+
audio=None,
|
| 607 |
+
ref_target_masks=None,
|
| 608 |
+
):
|
| 609 |
+
assert clip_fea is not None and y is not None
|
| 610 |
+
|
| 611 |
+
_, T, H, W = x[0].shape
|
| 612 |
+
N_t = T // self.patch_size[0]
|
| 613 |
+
N_h = H // self.patch_size[1]
|
| 614 |
+
N_w = W // self.patch_size[2]
|
| 615 |
+
|
| 616 |
+
if y is not None:
|
| 617 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 618 |
+
x[0] = x[0].to(context[0].dtype)
|
| 619 |
+
|
| 620 |
+
# embeddings
|
| 621 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 622 |
+
grid_sizes = torch.stack(
|
| 623 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 624 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 625 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 626 |
+
assert seq_lens.max() <= seq_len
|
| 627 |
+
x = torch.cat([
|
| 628 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 629 |
+
dim=1) for u in x
|
| 630 |
+
])
|
| 631 |
+
|
| 632 |
+
# time embeddings
|
| 633 |
+
with amp.autocast(dtype=torch.float32):
|
| 634 |
+
e = self.time_embedding(
|
| 635 |
+
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 636 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
| 637 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 638 |
+
|
| 639 |
+
# text embedding
|
| 640 |
+
context_lens = None
|
| 641 |
+
context = self.text_embedding(
|
| 642 |
+
torch.stack([
|
| 643 |
+
torch.cat(
|
| 644 |
+
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 645 |
+
for u in context
|
| 646 |
+
]))
|
| 647 |
+
|
| 648 |
+
# clip embedding
|
| 649 |
+
if clip_fea is not None:
|
| 650 |
+
context_clip = self.img_emb(clip_fea)
|
| 651 |
+
context = torch.concat([context_clip, context], dim=1).to(x.dtype)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
audio_cond = audio.to(device=x.device, dtype=x.dtype)
|
| 655 |
+
first_frame_audio_emb_s = audio_cond[:, :1, ...]
|
| 656 |
+
latter_frame_audio_emb = audio_cond[:, 1:, ...]
|
| 657 |
+
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=self.vae_scale)
|
| 658 |
+
middle_index = self.audio_window // 2
|
| 659 |
+
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
|
| 660 |
+
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
| 661 |
+
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
|
| 662 |
+
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
| 663 |
+
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
|
| 664 |
+
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
| 665 |
+
latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
|
| 666 |
+
audio_embedding = self.audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
|
| 667 |
+
human_num = len(audio_embedding)
|
| 668 |
+
audio_embedding = torch.concat(audio_embedding.split(1), dim=2).to(x.dtype)
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
# convert ref_target_masks to token_ref_target_masks
|
| 672 |
+
if ref_target_masks is not None:
|
| 673 |
+
ref_target_masks = ref_target_masks.unsqueeze(0).to(torch.float32)
|
| 674 |
+
token_ref_target_masks = nn.functional.interpolate(ref_target_masks, size=(N_h, N_w), mode='nearest')
|
| 675 |
+
token_ref_target_masks = token_ref_target_masks.squeeze(0)
|
| 676 |
+
token_ref_target_masks = (token_ref_target_masks > 0)
|
| 677 |
+
token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)
|
| 678 |
+
token_ref_target_masks = token_ref_target_masks.to(x.dtype)
|
| 679 |
+
|
| 680 |
+
# teacache
|
| 681 |
+
if self.enable_teacache:
|
| 682 |
+
modulated_inp = e0 if self.use_ret_steps else e
|
| 683 |
+
if self.cnt%3==0: # cond
|
| 684 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
| 685 |
+
should_calc_cond = True
|
| 686 |
+
self.accumulated_rel_l1_distance_cond = 0
|
| 687 |
+
else:
|
| 688 |
+
rescale_func = np.poly1d(self.coefficients)
|
| 689 |
+
self.accumulated_rel_l1_distance_cond += rescale_func(((modulated_inp-self.previous_e0_cond).abs().mean() / self.previous_e0_cond.abs().mean()).cpu().item())
|
| 690 |
+
if self.accumulated_rel_l1_distance_cond < self.teacache_thresh:
|
| 691 |
+
should_calc_cond = False
|
| 692 |
+
else:
|
| 693 |
+
should_calc_cond = True
|
| 694 |
+
self.accumulated_rel_l1_distance_cond = 0
|
| 695 |
+
self.previous_e0_cond = modulated_inp.clone()
|
| 696 |
+
elif self.cnt%3==1: # drop_text
|
| 697 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
| 698 |
+
should_calc_drop_text = True
|
| 699 |
+
self.accumulated_rel_l1_distance_drop_text = 0
|
| 700 |
+
else:
|
| 701 |
+
rescale_func = np.poly1d(self.coefficients)
|
| 702 |
+
self.accumulated_rel_l1_distance_drop_text += rescale_func(((modulated_inp-self.previous_e0_drop_text).abs().mean() / self.previous_e0_drop_text.abs().mean()).cpu().item())
|
| 703 |
+
if self.accumulated_rel_l1_distance_drop_text < self.teacache_thresh:
|
| 704 |
+
should_calc_drop_text = False
|
| 705 |
+
else:
|
| 706 |
+
should_calc_drop_text = True
|
| 707 |
+
self.accumulated_rel_l1_distance_drop_text = 0
|
| 708 |
+
self.previous_e0_drop_text = modulated_inp.clone()
|
| 709 |
+
else: # uncond
|
| 710 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
| 711 |
+
should_calc_uncond = True
|
| 712 |
+
self.accumulated_rel_l1_distance_uncond = 0
|
| 713 |
+
else:
|
| 714 |
+
rescale_func = np.poly1d(self.coefficients)
|
| 715 |
+
self.accumulated_rel_l1_distance_uncond += rescale_func(((modulated_inp-self.previous_e0_uncond).abs().mean() / self.previous_e0_uncond.abs().mean()).cpu().item())
|
| 716 |
+
if self.accumulated_rel_l1_distance_uncond < self.teacache_thresh:
|
| 717 |
+
should_calc_uncond = False
|
| 718 |
+
else:
|
| 719 |
+
should_calc_uncond = True
|
| 720 |
+
self.accumulated_rel_l1_distance_uncond = 0
|
| 721 |
+
self.previous_e0_uncond = modulated_inp.clone()
|
| 722 |
+
|
| 723 |
+
# arguments
|
| 724 |
+
kwargs = dict(
|
| 725 |
+
e=e0,
|
| 726 |
+
seq_lens=seq_lens,
|
| 727 |
+
grid_sizes=grid_sizes,
|
| 728 |
+
freqs=self.freqs,
|
| 729 |
+
context=context,
|
| 730 |
+
context_lens=context_lens,
|
| 731 |
+
audio_embedding=audio_embedding,
|
| 732 |
+
ref_target_masks=token_ref_target_masks,
|
| 733 |
+
human_num=human_num,
|
| 734 |
+
)
|
| 735 |
+
if self.enable_teacache:
|
| 736 |
+
if self.cnt%3==0:
|
| 737 |
+
if not should_calc_cond:
|
| 738 |
+
x += self.previous_residual_cond
|
| 739 |
+
else:
|
| 740 |
+
ori_x = x.clone()
|
| 741 |
+
for block in self.blocks:
|
| 742 |
+
x = block(x, **kwargs)
|
| 743 |
+
self.previous_residual_cond = x - ori_x
|
| 744 |
+
elif self.cnt%3==1:
|
| 745 |
+
if not should_calc_drop_text:
|
| 746 |
+
x += self.previous_residual_drop_text
|
| 747 |
+
else:
|
| 748 |
+
ori_x = x.clone()
|
| 749 |
+
for block in self.blocks:
|
| 750 |
+
x = block(x, **kwargs)
|
| 751 |
+
self.previous_residual_drop_text = x - ori_x
|
| 752 |
+
else:
|
| 753 |
+
if not should_calc_uncond:
|
| 754 |
+
x += self.previous_residual_uncond
|
| 755 |
+
else:
|
| 756 |
+
ori_x = x.clone()
|
| 757 |
+
for block in self.blocks:
|
| 758 |
+
x = block(x, **kwargs)
|
| 759 |
+
self.previous_residual_uncond = x - ori_x
|
| 760 |
+
else:
|
| 761 |
+
for block in self.blocks:
|
| 762 |
+
x = block(x, **kwargs)
|
| 763 |
+
|
| 764 |
+
# head
|
| 765 |
+
x = self.head(x, e)
|
| 766 |
+
|
| 767 |
+
# unpatchify
|
| 768 |
+
x = self.unpatchify(x, grid_sizes)
|
| 769 |
+
if self.enable_teacache:
|
| 770 |
+
self.cnt += 1
|
| 771 |
+
if self.cnt >= self.num_steps:
|
| 772 |
+
self.cnt = 0
|
| 773 |
+
|
| 774 |
+
return torch.stack(x).float()
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
def unpatchify(self, x, grid_sizes):
|
| 778 |
+
r"""
|
| 779 |
+
Reconstruct video tensors from patch embeddings.
|
| 780 |
+
|
| 781 |
+
Args:
|
| 782 |
+
x (List[Tensor]):
|
| 783 |
+
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
| 784 |
+
grid_sizes (Tensor):
|
| 785 |
+
Original spatial-temporal grid dimensions before patching,
|
| 786 |
+
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
| 787 |
+
|
| 788 |
+
Returns:
|
| 789 |
+
List[Tensor]:
|
| 790 |
+
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
|
| 791 |
+
"""
|
| 792 |
+
|
| 793 |
+
c = self.out_dim
|
| 794 |
+
out = []
|
| 795 |
+
for u, v in zip(x, grid_sizes.tolist()):
|
| 796 |
+
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
|
| 797 |
+
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
| 798 |
+
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
|
| 799 |
+
out.append(u)
|
| 800 |
+
return out
|
| 801 |
+
|
| 802 |
+
def init_weights(self):
|
| 803 |
+
r"""
|
| 804 |
+
Initialize model parameters using Xavier initialization.
|
| 805 |
+
"""
|
| 806 |
+
|
| 807 |
+
# basic init
|
| 808 |
+
for m in self.modules():
|
| 809 |
+
if isinstance(m, nn.Linear):
|
| 810 |
+
nn.init.xavier_uniform_(m.weight)
|
| 811 |
+
if m.bias is not None:
|
| 812 |
+
nn.init.zeros_(m.bias)
|
| 813 |
+
|
| 814 |
+
# init embeddings
|
| 815 |
+
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
|
| 816 |
+
for m in self.text_embedding.modules():
|
| 817 |
+
if isinstance(m, nn.Linear):
|
| 818 |
+
nn.init.normal_(m.weight, std=.02)
|
| 819 |
+
for m in self.time_embedding.modules():
|
| 820 |
+
if isinstance(m, nn.Linear):
|
| 821 |
+
nn.init.normal_(m.weight, std=.02)
|
| 822 |
+
|
| 823 |
+
# init output layer
|
| 824 |
+
nn.init.zeros_(self.head.head.weight)
|
wan/modules/t5.py
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from transformers.models.t5.modeling_t5
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
from safetensors.torch import load_file
|
| 13 |
+
from optimum.quanto import quantize, freeze, qint8,requantize
|
| 14 |
+
|
| 15 |
+
from .tokenizers import HuggingfaceTokenizer
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
'T5Model',
|
| 19 |
+
'T5Encoder',
|
| 20 |
+
'T5Decoder',
|
| 21 |
+
'T5EncoderModel',
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def fp16_clamp(x):
|
| 26 |
+
if x.dtype == torch.float16 and torch.isinf(x).any():
|
| 27 |
+
clamp = torch.finfo(x.dtype).max - 1000
|
| 28 |
+
x = torch.clamp(x, min=-clamp, max=clamp)
|
| 29 |
+
return x
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def init_weights(m):
|
| 33 |
+
if isinstance(m, T5LayerNorm):
|
| 34 |
+
nn.init.ones_(m.weight)
|
| 35 |
+
elif isinstance(m, T5Model):
|
| 36 |
+
nn.init.normal_(m.token_embedding.weight, std=1.0)
|
| 37 |
+
elif isinstance(m, T5FeedForward):
|
| 38 |
+
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
|
| 39 |
+
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
|
| 40 |
+
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
|
| 41 |
+
elif isinstance(m, T5Attention):
|
| 42 |
+
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
|
| 43 |
+
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
|
| 44 |
+
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
|
| 45 |
+
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
|
| 46 |
+
elif isinstance(m, T5RelativeEmbedding):
|
| 47 |
+
nn.init.normal_(
|
| 48 |
+
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class GELU(nn.Module):
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
return 0.5 * x * (1.0 + torch.tanh(
|
| 55 |
+
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class T5LayerNorm(nn.Module):
|
| 59 |
+
|
| 60 |
+
def __init__(self, dim, eps=1e-6):
|
| 61 |
+
super(T5LayerNorm, self).__init__()
|
| 62 |
+
self.dim = dim
|
| 63 |
+
self.eps = eps
|
| 64 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 65 |
+
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
|
| 68 |
+
self.eps)
|
| 69 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
| 70 |
+
x = x.type_as(self.weight)
|
| 71 |
+
return self.weight * x
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class T5Attention(nn.Module):
|
| 75 |
+
|
| 76 |
+
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
|
| 77 |
+
assert dim_attn % num_heads == 0
|
| 78 |
+
super(T5Attention, self).__init__()
|
| 79 |
+
self.dim = dim
|
| 80 |
+
self.dim_attn = dim_attn
|
| 81 |
+
self.num_heads = num_heads
|
| 82 |
+
self.head_dim = dim_attn // num_heads
|
| 83 |
+
|
| 84 |
+
# layers
|
| 85 |
+
self.q = nn.Linear(dim, dim_attn, bias=False)
|
| 86 |
+
self.k = nn.Linear(dim, dim_attn, bias=False)
|
| 87 |
+
self.v = nn.Linear(dim, dim_attn, bias=False)
|
| 88 |
+
self.o = nn.Linear(dim_attn, dim, bias=False)
|
| 89 |
+
self.dropout = nn.Dropout(dropout)
|
| 90 |
+
|
| 91 |
+
def forward(self, x, context=None, mask=None, pos_bias=None):
|
| 92 |
+
"""
|
| 93 |
+
x: [B, L1, C].
|
| 94 |
+
context: [B, L2, C] or None.
|
| 95 |
+
mask: [B, L2] or [B, L1, L2] or None.
|
| 96 |
+
"""
|
| 97 |
+
# check inputs
|
| 98 |
+
context = x if context is None else context
|
| 99 |
+
b, n, c = x.size(0), self.num_heads, self.head_dim
|
| 100 |
+
|
| 101 |
+
# compute query, key, value
|
| 102 |
+
q = self.q(x).view(b, -1, n, c)
|
| 103 |
+
k = self.k(context).view(b, -1, n, c)
|
| 104 |
+
v = self.v(context).view(b, -1, n, c)
|
| 105 |
+
|
| 106 |
+
# attention bias
|
| 107 |
+
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
|
| 108 |
+
if pos_bias is not None:
|
| 109 |
+
attn_bias += pos_bias
|
| 110 |
+
if mask is not None:
|
| 111 |
+
assert mask.ndim in [2, 3]
|
| 112 |
+
mask = mask.view(b, 1, 1,
|
| 113 |
+
-1) if mask.ndim == 2 else mask.unsqueeze(1)
|
| 114 |
+
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
|
| 115 |
+
|
| 116 |
+
# compute attention (T5 does not use scaling)
|
| 117 |
+
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
|
| 118 |
+
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
| 119 |
+
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
| 120 |
+
|
| 121 |
+
# output
|
| 122 |
+
x = x.reshape(b, -1, n * c)
|
| 123 |
+
x = self.o(x)
|
| 124 |
+
x = self.dropout(x)
|
| 125 |
+
return x
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class T5FeedForward(nn.Module):
|
| 129 |
+
|
| 130 |
+
def __init__(self, dim, dim_ffn, dropout=0.1):
|
| 131 |
+
super(T5FeedForward, self).__init__()
|
| 132 |
+
self.dim = dim
|
| 133 |
+
self.dim_ffn = dim_ffn
|
| 134 |
+
|
| 135 |
+
# layers
|
| 136 |
+
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
|
| 137 |
+
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
| 138 |
+
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
| 139 |
+
self.dropout = nn.Dropout(dropout)
|
| 140 |
+
|
| 141 |
+
def forward(self, x):
|
| 142 |
+
x = self.fc1(x) * self.gate(x)
|
| 143 |
+
x = self.dropout(x)
|
| 144 |
+
x = self.fc2(x)
|
| 145 |
+
x = self.dropout(x)
|
| 146 |
+
return x
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class T5SelfAttention(nn.Module):
|
| 150 |
+
|
| 151 |
+
def __init__(self,
|
| 152 |
+
dim,
|
| 153 |
+
dim_attn,
|
| 154 |
+
dim_ffn,
|
| 155 |
+
num_heads,
|
| 156 |
+
num_buckets,
|
| 157 |
+
shared_pos=True,
|
| 158 |
+
dropout=0.1):
|
| 159 |
+
super(T5SelfAttention, self).__init__()
|
| 160 |
+
self.dim = dim
|
| 161 |
+
self.dim_attn = dim_attn
|
| 162 |
+
self.dim_ffn = dim_ffn
|
| 163 |
+
self.num_heads = num_heads
|
| 164 |
+
self.num_buckets = num_buckets
|
| 165 |
+
self.shared_pos = shared_pos
|
| 166 |
+
|
| 167 |
+
# layers
|
| 168 |
+
self.norm1 = T5LayerNorm(dim)
|
| 169 |
+
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
| 170 |
+
self.norm2 = T5LayerNorm(dim)
|
| 171 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
| 172 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
| 173 |
+
num_buckets, num_heads, bidirectional=True)
|
| 174 |
+
|
| 175 |
+
def forward(self, x, mask=None, pos_bias=None):
|
| 176 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(
|
| 177 |
+
x.size(1), x.size(1))
|
| 178 |
+
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
| 179 |
+
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
| 180 |
+
return x
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class T5CrossAttention(nn.Module):
|
| 184 |
+
|
| 185 |
+
def __init__(self,
|
| 186 |
+
dim,
|
| 187 |
+
dim_attn,
|
| 188 |
+
dim_ffn,
|
| 189 |
+
num_heads,
|
| 190 |
+
num_buckets,
|
| 191 |
+
shared_pos=True,
|
| 192 |
+
dropout=0.1):
|
| 193 |
+
super(T5CrossAttention, self).__init__()
|
| 194 |
+
self.dim = dim
|
| 195 |
+
self.dim_attn = dim_attn
|
| 196 |
+
self.dim_ffn = dim_ffn
|
| 197 |
+
self.num_heads = num_heads
|
| 198 |
+
self.num_buckets = num_buckets
|
| 199 |
+
self.shared_pos = shared_pos
|
| 200 |
+
|
| 201 |
+
# layers
|
| 202 |
+
self.norm1 = T5LayerNorm(dim)
|
| 203 |
+
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
| 204 |
+
self.norm2 = T5LayerNorm(dim)
|
| 205 |
+
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
| 206 |
+
self.norm3 = T5LayerNorm(dim)
|
| 207 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
| 208 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
| 209 |
+
num_buckets, num_heads, bidirectional=False)
|
| 210 |
+
|
| 211 |
+
def forward(self,
|
| 212 |
+
x,
|
| 213 |
+
mask=None,
|
| 214 |
+
encoder_states=None,
|
| 215 |
+
encoder_mask=None,
|
| 216 |
+
pos_bias=None):
|
| 217 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(
|
| 218 |
+
x.size(1), x.size(1))
|
| 219 |
+
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
|
| 220 |
+
x = fp16_clamp(x + self.cross_attn(
|
| 221 |
+
self.norm2(x), context=encoder_states, mask=encoder_mask))
|
| 222 |
+
x = fp16_clamp(x + self.ffn(self.norm3(x)))
|
| 223 |
+
return x
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class T5RelativeEmbedding(nn.Module):
|
| 227 |
+
|
| 228 |
+
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
| 229 |
+
super(T5RelativeEmbedding, self).__init__()
|
| 230 |
+
self.num_buckets = num_buckets
|
| 231 |
+
self.num_heads = num_heads
|
| 232 |
+
self.bidirectional = bidirectional
|
| 233 |
+
self.max_dist = max_dist
|
| 234 |
+
|
| 235 |
+
# layers
|
| 236 |
+
self.embedding = nn.Embedding(num_buckets, num_heads)
|
| 237 |
+
|
| 238 |
+
def forward(self, lq, lk):
|
| 239 |
+
device = self.embedding.weight.device
|
| 240 |
+
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
|
| 241 |
+
# torch.arange(lq).unsqueeze(1).to(device)
|
| 242 |
+
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
|
| 243 |
+
torch.arange(lq, device=device).unsqueeze(1)
|
| 244 |
+
rel_pos = self._relative_position_bucket(rel_pos)
|
| 245 |
+
rel_pos_embeds = self.embedding(rel_pos)
|
| 246 |
+
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
|
| 247 |
+
0) # [1, N, Lq, Lk]
|
| 248 |
+
return rel_pos_embeds.contiguous()
|
| 249 |
+
|
| 250 |
+
def _relative_position_bucket(self, rel_pos):
|
| 251 |
+
# preprocess
|
| 252 |
+
if self.bidirectional:
|
| 253 |
+
num_buckets = self.num_buckets // 2
|
| 254 |
+
rel_buckets = (rel_pos > 0).long() * num_buckets
|
| 255 |
+
rel_pos = torch.abs(rel_pos)
|
| 256 |
+
else:
|
| 257 |
+
num_buckets = self.num_buckets
|
| 258 |
+
rel_buckets = 0
|
| 259 |
+
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
|
| 260 |
+
|
| 261 |
+
# embeddings for small and large positions
|
| 262 |
+
max_exact = num_buckets // 2
|
| 263 |
+
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
|
| 264 |
+
math.log(self.max_dist / max_exact) *
|
| 265 |
+
(num_buckets - max_exact)).long()
|
| 266 |
+
rel_pos_large = torch.min(
|
| 267 |
+
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
|
| 268 |
+
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
|
| 269 |
+
return rel_buckets
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class T5Encoder(nn.Module):
|
| 273 |
+
|
| 274 |
+
def __init__(self,
|
| 275 |
+
vocab,
|
| 276 |
+
dim,
|
| 277 |
+
dim_attn,
|
| 278 |
+
dim_ffn,
|
| 279 |
+
num_heads,
|
| 280 |
+
num_layers,
|
| 281 |
+
num_buckets,
|
| 282 |
+
shared_pos=True,
|
| 283 |
+
dropout=0.1):
|
| 284 |
+
super(T5Encoder, self).__init__()
|
| 285 |
+
self.dim = dim
|
| 286 |
+
self.dim_attn = dim_attn
|
| 287 |
+
self.dim_ffn = dim_ffn
|
| 288 |
+
self.num_heads = num_heads
|
| 289 |
+
self.num_layers = num_layers
|
| 290 |
+
self.num_buckets = num_buckets
|
| 291 |
+
self.shared_pos = shared_pos
|
| 292 |
+
|
| 293 |
+
# layers
|
| 294 |
+
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
|
| 295 |
+
else nn.Embedding(vocab, dim)
|
| 296 |
+
self.pos_embedding = T5RelativeEmbedding(
|
| 297 |
+
num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
| 298 |
+
self.dropout = nn.Dropout(dropout)
|
| 299 |
+
self.blocks = nn.ModuleList([
|
| 300 |
+
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
| 301 |
+
shared_pos, dropout) for _ in range(num_layers)
|
| 302 |
+
])
|
| 303 |
+
self.norm = T5LayerNorm(dim)
|
| 304 |
+
|
| 305 |
+
# initialize weights
|
| 306 |
+
self.apply(init_weights)
|
| 307 |
+
|
| 308 |
+
def forward(self, ids, mask=None):
|
| 309 |
+
x = self.token_embedding(ids)
|
| 310 |
+
x = self.dropout(x)
|
| 311 |
+
e = self.pos_embedding(x.size(1),
|
| 312 |
+
x.size(1)) if self.shared_pos else None
|
| 313 |
+
for block in self.blocks:
|
| 314 |
+
x = block(x, mask, pos_bias=e)
|
| 315 |
+
x = self.norm(x)
|
| 316 |
+
x = self.dropout(x)
|
| 317 |
+
return x
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class T5Decoder(nn.Module):
|
| 321 |
+
|
| 322 |
+
def __init__(self,
|
| 323 |
+
vocab,
|
| 324 |
+
dim,
|
| 325 |
+
dim_attn,
|
| 326 |
+
dim_ffn,
|
| 327 |
+
num_heads,
|
| 328 |
+
num_layers,
|
| 329 |
+
num_buckets,
|
| 330 |
+
shared_pos=True,
|
| 331 |
+
dropout=0.1):
|
| 332 |
+
super(T5Decoder, self).__init__()
|
| 333 |
+
self.dim = dim
|
| 334 |
+
self.dim_attn = dim_attn
|
| 335 |
+
self.dim_ffn = dim_ffn
|
| 336 |
+
self.num_heads = num_heads
|
| 337 |
+
self.num_layers = num_layers
|
| 338 |
+
self.num_buckets = num_buckets
|
| 339 |
+
self.shared_pos = shared_pos
|
| 340 |
+
|
| 341 |
+
# layers
|
| 342 |
+
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
|
| 343 |
+
else nn.Embedding(vocab, dim)
|
| 344 |
+
self.pos_embedding = T5RelativeEmbedding(
|
| 345 |
+
num_buckets, num_heads, bidirectional=False) if shared_pos else None
|
| 346 |
+
self.dropout = nn.Dropout(dropout)
|
| 347 |
+
self.blocks = nn.ModuleList([
|
| 348 |
+
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
| 349 |
+
shared_pos, dropout) for _ in range(num_layers)
|
| 350 |
+
])
|
| 351 |
+
self.norm = T5LayerNorm(dim)
|
| 352 |
+
|
| 353 |
+
# initialize weights
|
| 354 |
+
self.apply(init_weights)
|
| 355 |
+
|
| 356 |
+
def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
|
| 357 |
+
b, s = ids.size()
|
| 358 |
+
|
| 359 |
+
# causal mask
|
| 360 |
+
if mask is None:
|
| 361 |
+
mask = torch.tril(torch.ones(1, s, s).to(ids.device))
|
| 362 |
+
elif mask.ndim == 2:
|
| 363 |
+
mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
|
| 364 |
+
|
| 365 |
+
# layers
|
| 366 |
+
x = self.token_embedding(ids)
|
| 367 |
+
x = self.dropout(x)
|
| 368 |
+
e = self.pos_embedding(x.size(1),
|
| 369 |
+
x.size(1)) if self.shared_pos else None
|
| 370 |
+
for block in self.blocks:
|
| 371 |
+
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
|
| 372 |
+
x = self.norm(x)
|
| 373 |
+
x = self.dropout(x)
|
| 374 |
+
return x
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class T5Model(nn.Module):
|
| 378 |
+
|
| 379 |
+
def __init__(self,
|
| 380 |
+
vocab_size,
|
| 381 |
+
dim,
|
| 382 |
+
dim_attn,
|
| 383 |
+
dim_ffn,
|
| 384 |
+
num_heads,
|
| 385 |
+
encoder_layers,
|
| 386 |
+
decoder_layers,
|
| 387 |
+
num_buckets,
|
| 388 |
+
shared_pos=True,
|
| 389 |
+
dropout=0.1):
|
| 390 |
+
super(T5Model, self).__init__()
|
| 391 |
+
self.vocab_size = vocab_size
|
| 392 |
+
self.dim = dim
|
| 393 |
+
self.dim_attn = dim_attn
|
| 394 |
+
self.dim_ffn = dim_ffn
|
| 395 |
+
self.num_heads = num_heads
|
| 396 |
+
self.encoder_layers = encoder_layers
|
| 397 |
+
self.decoder_layers = decoder_layers
|
| 398 |
+
self.num_buckets = num_buckets
|
| 399 |
+
|
| 400 |
+
# layers
|
| 401 |
+
self.token_embedding = nn.Embedding(vocab_size, dim)
|
| 402 |
+
self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
|
| 403 |
+
num_heads, encoder_layers, num_buckets,
|
| 404 |
+
shared_pos, dropout)
|
| 405 |
+
self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
|
| 406 |
+
num_heads, decoder_layers, num_buckets,
|
| 407 |
+
shared_pos, dropout)
|
| 408 |
+
self.head = nn.Linear(dim, vocab_size, bias=False)
|
| 409 |
+
|
| 410 |
+
# initialize weights
|
| 411 |
+
self.apply(init_weights)
|
| 412 |
+
|
| 413 |
+
def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
|
| 414 |
+
x = self.encoder(encoder_ids, encoder_mask)
|
| 415 |
+
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
|
| 416 |
+
x = self.head(x)
|
| 417 |
+
return x
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def _t5(name,
|
| 421 |
+
encoder_only=False,
|
| 422 |
+
decoder_only=False,
|
| 423 |
+
return_tokenizer=False,
|
| 424 |
+
tokenizer_kwargs={},
|
| 425 |
+
dtype=torch.float32,
|
| 426 |
+
device='cpu',
|
| 427 |
+
**kwargs):
|
| 428 |
+
# sanity check
|
| 429 |
+
assert not (encoder_only and decoder_only)
|
| 430 |
+
|
| 431 |
+
# params
|
| 432 |
+
if encoder_only:
|
| 433 |
+
model_cls = T5Encoder
|
| 434 |
+
kwargs['vocab'] = kwargs.pop('vocab_size')
|
| 435 |
+
kwargs['num_layers'] = kwargs.pop('encoder_layers')
|
| 436 |
+
_ = kwargs.pop('decoder_layers')
|
| 437 |
+
elif decoder_only:
|
| 438 |
+
model_cls = T5Decoder
|
| 439 |
+
kwargs['vocab'] = kwargs.pop('vocab_size')
|
| 440 |
+
kwargs['num_layers'] = kwargs.pop('decoder_layers')
|
| 441 |
+
_ = kwargs.pop('encoder_layers')
|
| 442 |
+
else:
|
| 443 |
+
model_cls = T5Model
|
| 444 |
+
|
| 445 |
+
# init model
|
| 446 |
+
with torch.device(device):
|
| 447 |
+
model = model_cls(**kwargs)
|
| 448 |
+
|
| 449 |
+
# set device
|
| 450 |
+
model = model.to(dtype=dtype, device=device)
|
| 451 |
+
|
| 452 |
+
# init tokenizer
|
| 453 |
+
if return_tokenizer:
|
| 454 |
+
from .tokenizers import HuggingfaceTokenizer
|
| 455 |
+
tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
|
| 456 |
+
return model, tokenizer
|
| 457 |
+
else:
|
| 458 |
+
return model
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def umt5_xxl(**kwargs):
|
| 462 |
+
cfg = dict(
|
| 463 |
+
vocab_size=256384,
|
| 464 |
+
dim=4096,
|
| 465 |
+
dim_attn=4096,
|
| 466 |
+
dim_ffn=10240,
|
| 467 |
+
num_heads=64,
|
| 468 |
+
encoder_layers=24,
|
| 469 |
+
decoder_layers=24,
|
| 470 |
+
num_buckets=32,
|
| 471 |
+
shared_pos=False,
|
| 472 |
+
dropout=0.1)
|
| 473 |
+
cfg.update(**kwargs)
|
| 474 |
+
return _t5('umt5-xxl', **cfg)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
class T5EncoderModel:
|
| 478 |
+
|
| 479 |
+
def __init__(
|
| 480 |
+
self,
|
| 481 |
+
text_len,
|
| 482 |
+
dtype=torch.bfloat16,
|
| 483 |
+
device=torch.cuda.current_device(),
|
| 484 |
+
checkpoint_path=None,
|
| 485 |
+
tokenizer_path=None,
|
| 486 |
+
shard_fn=None,
|
| 487 |
+
quant=None,
|
| 488 |
+
quant_dir=None
|
| 489 |
+
):
|
| 490 |
+
assert quant is None or quant in ("int8", "fp8")
|
| 491 |
+
self.text_len = text_len
|
| 492 |
+
self.dtype = dtype
|
| 493 |
+
self.device = device
|
| 494 |
+
self.checkpoint_path = checkpoint_path
|
| 495 |
+
self.tokenizer_path = tokenizer_path
|
| 496 |
+
|
| 497 |
+
# init model
|
| 498 |
+
logging.info(f'loading {checkpoint_path}')
|
| 499 |
+
if quant is not None:
|
| 500 |
+
with torch.device('meta'):
|
| 501 |
+
model = umt5_xxl(
|
| 502 |
+
encoder_only=True,
|
| 503 |
+
return_tokenizer=False,
|
| 504 |
+
dtype=dtype,
|
| 505 |
+
device=torch.device('meta'))
|
| 506 |
+
logging.info(f'Loading quantized T5 from {os.path.join(quant_dir, f"t5_{quant}.safetensors")}')
|
| 507 |
+
model_state_dict = load_file(os.path.join(quant_dir, f"t5_{quant}.safetensors"))
|
| 508 |
+
with open(os.path.join(quant_dir, f"t5_map_{quant}.json"), "r") as f:
|
| 509 |
+
quantization_map = json.load(f)
|
| 510 |
+
requantize(model, model_state_dict, quantization_map, device='cpu')
|
| 511 |
+
else:
|
| 512 |
+
model = umt5_xxl(
|
| 513 |
+
encoder_only=True,
|
| 514 |
+
return_tokenizer=False,
|
| 515 |
+
dtype=dtype,
|
| 516 |
+
device=device).eval().requires_grad_(False)
|
| 517 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
|
| 518 |
+
self.model = model
|
| 519 |
+
self.model.eval().requires_grad_(False)
|
| 520 |
+
if shard_fn is not None:
|
| 521 |
+
self.model = shard_fn(self.model, sync_module_states=False)
|
| 522 |
+
else:
|
| 523 |
+
self.model.to(self.device)
|
| 524 |
+
# init tokenizer
|
| 525 |
+
self.tokenizer = HuggingfaceTokenizer(
|
| 526 |
+
name=tokenizer_path, seq_len=text_len, clean='whitespace')
|
| 527 |
+
|
| 528 |
+
def __call__(self, texts, device):
|
| 529 |
+
ids, mask = self.tokenizer(
|
| 530 |
+
texts, return_mask=True, add_special_tokens=True)
|
| 531 |
+
ids = ids.to(device)
|
| 532 |
+
mask = mask.to(device)
|
| 533 |
+
seq_lens = mask.gt(0).sum(dim=1).long()
|
| 534 |
+
context = self.model(ids, mask)
|
| 535 |
+
return [u[:v] for u, v in zip(context, seq_lens)]
|