Spaces: Running on Zero

Commit c8df52d (parent: 123eeba): add app!
Changed files:
- .gitignore +4 -0
- README.md +5 -3
- app.py +651 -0
- concurrency_manager.py +203 -0
- index.html +2130 -0
- models/__init__.py +5 -0
- models/autoencoder_kl_wan.py +1467 -0
- models/reconstruction_model.py +261 -0
- models/render.py +138 -0
- models/transformer_wan.py +601 -0
- quant.py +195 -0
- requirements.txt +19 -0
- utils.py +531 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+tmpfiles/
+model.ckpt
+
+**/__pycache__/**
README.md CHANGED
@@ -38,6 +38,8 @@ pip install torch torchvision
 pip install triton transformers pytorch_lightning omegaconf ninja numpy jaxtyping rich tensorboard einops moviepy==1.0.3 webdataset accelerate opencv-python lpips av plyfile ftfy peft tensorboard pandas flask
 ```
 
+Please refer to the `requirements.txt` file for the exact package versions.
+
 - install ```gsplat@1.5.2``` and ```diffusers@wan-5Bi2v``` packages
 ```
 pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712
@@ -55,9 +57,10 @@ cd FlashWorld
 python app.py
 ```
 
-Then, enjoy your journey in FlashWorld!
-
+Then, open your web browser and navigate to ```http://HOST_IP:7860``` to start exploring FlashWorld!
 
+<!-- We also provide example trajectory json files and input images in the `examples/` directory. -->
+
 ## More Generation Results
 
 [https://github.com/user-attachments/assets/bbdbe5de-5e15-4471-b380-4d8191688d82](https://github.com/user-attachments/assets/53d41748-4c35-48c4-9771-f458421c0b38)
@@ -67,7 +70,6 @@ Then, enjoy your journey in FlashWorld!
 
 Licensed under the CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International)
 
-
 The code is released for academic research use only.
 
 If you have any questions, please contact me via [imlixinyang@gmail.com](mailto:imlixinyang@gmail.com).
app.py ADDED
@@ -0,0 +1,651 @@
try:
    import spaces
    GPU = spaces.GPU
    print("spaces GPU is available")
except ImportError:
    def GPU(func):
        return func

import os
import subprocess

# def install_cuda_toolkit():
#     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
#     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
#     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
#     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
#     subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
#     subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

#     os.environ["CUDA_HOME"] = "/usr/local/cuda"
#     os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
#     os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
#         os.environ["CUDA_HOME"],
#         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
#     )
#     # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
#     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"

#     print("Successfully installed CUDA toolkit at: ", os.environ["CUDA_HOME"])

#     subprocess.call('rm /usr/bin/gcc', shell=True)
#     subprocess.call('rm /usr/bin/g++', shell=True)
#     subprocess.call('rm /usr/local/cuda/bin/gcc', shell=True)
#     subprocess.call('rm /usr/local/cuda/bin/g++', shell=True)

#     subprocess.call('ln -s /usr/bin/gcc-11 /usr/bin/gcc', shell=True)
#     subprocess.call('ln -s /usr/bin/g++-11 /usr/bin/g++', shell=True)

#     subprocess.call('ln -s /usr/bin/gcc-11 /usr/local/cuda/bin/gcc', shell=True)
#     subprocess.call('ln -s /usr/bin/g++-11 /usr/local/cuda/bin/g++', shell=True)

#     subprocess.call('gcc --version', shell=True)
#     subprocess.call('g++ --version', shell=True)

# install_cuda_toolkit()

# subprocess.run('pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712 --no-build-isolation --use-pep517', env={'CUDA_HOME': "/usr/local/cuda", "TORCH_CUDA_ARCH_LIST": "8.0;8.6"}, shell=True)

from flask import Flask, jsonify, request, send_file, render_template
import base64
import io
from PIL import Image
import torch
import numpy as np
import os
import argparse
import imageio
import json

import time
import threading

from concurrency_manager import ConcurrencyManager

from huggingface_hub import hf_hub_download

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import imageio

from models import *
from utils import *

from transformers import T5TokenizerFast, UMT5EncoderModel

from diffusers import FlowMatchEulerDiscreteScheduler

class MyFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        return torch.argmin(
            (timestep - schedule_timesteps.to(timestep.device)).abs(), dim=0).item()

class GenerationSystem(nn.Module):
    def __init__(self, ckpt_path=None, device="cuda:0", offload_t5=False, offload_vae=False):
        super().__init__()
        self.device = device
        self.offload_t5 = offload_t5
        self.offload_vae = offload_vae

        self.latent_dim = 48
        self.temporal_downsample_factor = 4
        self.spatial_downsample_factor = 16

        self.feat_dim = 1024

        self.latent_patch_size = 2

        self.denoising_steps = [0, 250, 500, 750]

        model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"

        self.vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float).eval()

        from models.autoencoder_kl_wan import WanCausalConv3d
        with torch.no_grad():
            for name, module in self.vae.named_modules():
                if isinstance(module, WanCausalConv3d):
                    time_pad = module._padding[4]
                    module.padding = (0, module._padding[2], module._padding[0])
                    module._padding = (0, 0, 0, 0, 0, 0)
                    module.weight = torch.nn.Parameter(module.weight[:, :, time_pad:].clone())

        self.vae.requires_grad_(False)

        self.register_buffer('latents_mean', torch.tensor(self.vae.config.latents_mean).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
        self.register_buffer('latents_std', torch.tensor(self.vae.config.latents_std).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))

        self.latent_scale_fn = lambda x: (x - self.latents_mean) / self.latents_std
        self.latent_unscale_fn = lambda x: x * self.latents_std + self.latents_mean

        self.tokenizer = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer")

        self.text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float32).eval().requires_grad_(False).to(self.device if not self.offload_t5 else "cpu")

        self.transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float32).train().requires_grad_(False)

        self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, 6 + self.latent_dim)))
        # self.transformer.rope.freqs_f[:] = self.transformer.rope.freqs_f[:1]

        weight = self.transformer.proj_out.weight.reshape(self.latent_patch_size ** 2, self.latent_dim, self.transformer.proj_out.weight.shape[1])
        bias = self.transformer.proj_out.bias.reshape(self.latent_patch_size ** 2, self.latent_dim)

        extra_weight = torch.randn(self.latent_patch_size ** 2, self.feat_dim, self.transformer.proj_out.weight.shape[1]) * 0.02
        extra_bias = torch.zeros(self.latent_patch_size ** 2, self.feat_dim)

        self.transformer.proj_out.weight = nn.Parameter(torch.cat([weight, extra_weight], dim=1).flatten(0, 1).detach().clone())
        self.transformer.proj_out.bias = nn.Parameter(torch.cat([bias, extra_bias], dim=1).flatten(0, 1).detach().clone())

        self.recon_decoder = WANDecoderPixelAligned3DGSReconstructionModel(self.vae, self.feat_dim, use_render_checkpointing=True, use_network_checkpointing=False).train().requires_grad_(False).to(self.device)

        self.scheduler = MyFlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3)

        self.register_buffer('timesteps', self.scheduler.timesteps.clone().to(self.device))

        self.transformer.disable_gradient_checkpointing()
        self.transformer.gradient_checkpointing = False

        self.add_feedback_for_transformer()

        if ckpt_path is not None:
            state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            self.transformer.load_state_dict(state_dict["transformer"])
            self.recon_decoder.load_state_dict(state_dict["recon_decoder"])
            print(f"Loaded {ckpt_path}.")

        from quant import FluxFp8GeMMProcessor

        FluxFp8GeMMProcessor(self.transformer)

        del self.vae.post_quant_conv, self.vae.decoder
        self.vae.to(self.device if not self.offload_vae else "cpu")

        self.transformer.to(self.device)

    def add_feedback_for_transformer(self):
        self.use_feedback = True
        self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, self.feat_dim + self.latent_dim)))

    def encode_text(self, texts):
        max_sequence_length = 512

        text_inputs = self.tokenizer(
            texts,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        if getattr(self, "offload_t5", False):
            text_input_ids = text_inputs.input_ids.to("cpu")
            mask = text_inputs.attention_mask.to("cpu")
        else:
            text_input_ids = text_inputs.input_ids.to(self.device)
            mask = text_inputs.attention_mask.to(self.device)
        seq_lens = mask.gt(0).sum(dim=1).long()

        if getattr(self, "offload_t5", False):
            with torch.no_grad():
                text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state.to(self.device)
        else:
            text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
        text_embeds = [u[:v] for u, v in zip(text_embeds, seq_lens)]
        text_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in text_embeds], dim=0
        )
        return text_embeds.float()

    def forward_generator(self, noisy_latents, raymaps, condition_latents, t, text_embeds, cameras, render_cameras, image_height, image_width, need_3d_mode=True):

        out = self.transformer(
            hidden_states=torch.cat([noisy_latents, raymaps, condition_latents], dim=1),
            timestep=t,
            encoder_hidden_states=text_embeds,
            return_dict=False,
        )[0]

        v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)

        sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
        latents_pred_2d = noisy_latents - sigma * v_pred

        if need_3d_mode:
            scene_params = self.recon_decoder(
                einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
                einops.rearrange(self.latent_unscale_fn(latents_pred_2d.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
                cameras
            ).flatten(1, -2)

            images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")

            latents_pred_3d = einops.rearrange(self.latent_scale_fn(self.vae.encode(
                einops.rearrange(images_pred, 'B T C H W -> (B T) C H W', T=images_pred.shape[1]).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
            ).latent_dist.sample().to(self.device)).squeeze(2), '(B T) C H W -> B C T H W', T=images_pred.shape[1]).to(noisy_latents.dtype)

        return {
            '2d': latents_pred_2d,
            '3d': latents_pred_3d if need_3d_mode else None,
            'rgb_3d': images_pred if need_3d_mode else None,
            'scene': scene_params if need_3d_mode else None,
            'feat': feats
        }

    @torch.no_grad()
    @torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda")
    def generate(self, cameras, n_frame, image=None, text="", image_index=0, image_height=480, image_width=704, video_output_path=None):
        with torch.no_grad():
            batch_size = 1

            cameras = cameras.to(self.device).unsqueeze(0)

            if cameras.shape[1] != n_frame:
                render_cameras = cameras.clone()
                cameras = sample_from_dense_cameras(cameras.squeeze(0), torch.linspace(0, 1, n_frame, device=self.device)).unsqueeze(0)
            else:
                render_cameras = cameras

            cameras, ref_w2c, T_norm = normalize_cameras(cameras, return_meta=True, n_frame=None)

            render_cameras = normalize_cameras(render_cameras, ref_w2c=ref_w2c, T_norm=T_norm, n_frame=None)

            text = "[Static] " + text

            text_embeds = self.encode_text([text])
            # neg_text_embeds = self.encode_text([""]).repeat(batch_size, 1, 1)

            masks = torch.zeros(batch_size, n_frame, device=self.device)

            condition_latents = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)

            if image is not None:
                image = image.to(self.device)

                latent = self.latent_scale_fn(self.vae.encode(
                    image.unsqueeze(0).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
                ).latent_dist.sample().to(self.device)).squeeze(2)

                masks[:, image_index] = 1
                condition_latents[:, :, image_index] = latent

            raymaps = create_raymaps(cameras, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor)
            raymaps = einops.rearrange(raymaps, 'B T H W C -> B C T H W', T=n_frame)

            noise = torch.randn(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)

            noisy_latents = noise

            torch.cuda.empty_cache()

            if self.use_feedback:
                prev_latents_pred = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)

                prev_feats = torch.zeros(batch_size, self.feat_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)

            for i in range(len(self.denoising_steps)):
                t_ids = torch.full((noisy_latents.shape[0],), self.denoising_steps[i], device=self.device)

                t = self.timesteps[t_ids]

                if self.use_feedback:
                    _condition_latents = torch.cat([condition_latents, prev_feats, prev_latents_pred], dim=1)
                else:
                    _condition_latents = condition_latents

                if i < len(self.denoising_steps) - 1:
                    out = self.forward_generator(noisy_latents, raymaps, _condition_latents, t, text_embeds, cameras, cameras, image_height, image_width, need_3d_mode=True)

                    latents_pred = out["3d"]

                    if self.use_feedback:
                        prev_latents_pred = latents_pred
                        prev_feats = out['feat']

                    noisy_latents = self.scheduler.scale_noise(latents_pred, self.timesteps[torch.full((noisy_latents.shape[0],), self.denoising_steps[i + 1], device=self.device)], torch.randn_like(noise))

                else:
                    out = self.transformer(
                        hidden_states=torch.cat([noisy_latents, raymaps, _condition_latents], dim=1),
                        timestep=t,
                        encoder_hidden_states=text_embeds,
                        return_dict=False,
                    )[0]

                    v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)

                    sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
                    latents_pred = noisy_latents - sigma * v_pred

                    scene_params = self.recon_decoder(
                        einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
                        einops.rearrange(self.latent_unscale_fn(latents_pred.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
                        cameras
                    ).flatten(1, -2)

            if video_output_path is not None:
                interpolated_images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")

                interpolated_images_pred = einops.rearrange(interpolated_images_pred[0].clamp(-1, 1).add(1).div(2), 'T C H W -> T H W C')

                interpolated_images_pred = [torch.cat([img], dim=1).detach().cpu().mul(255).numpy().astype(np.uint8) for i, img in enumerate(interpolated_images_pred.unbind(0))]

                imageio.mimwrite(video_output_path, interpolated_images_pred, fps=15, quality=8, macro_block_size=1)

            scene_params = scene_params[0]

            scene_params = scene_params.detach().cpu()

            return scene_params, ref_w2c, T_norm

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', type=int, default=7860)
    parser.add_argument("--ckpt", default=None)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--cache_dir", type=str, default="./tmpfiles")
    parser.add_argument("--offload_t5", type=bool, default=False)
    parser.add_argument("--max_concurrent", type=int, default=1, help="Maximum concurrent generation tasks")
    args, _ = parser.parse_known_args()

    # Ensure model.ckpt exists, download if not present
    if args.ckpt is None:
        from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
        ckpt_path = os.path.join(HUGGINGFACE_HUB_CACHE, "models--imlixinyang--FlashWorld", "snapshots", "6a8e88c6f88678ac098e4c82675f0aee555d6e5d", "model.ckpt")
        if not os.path.exists(ckpt_path):
            hf_hub_download(repo_id="imlixinyang/FlashWorld", filename="model.ckpt", local_dir_use_symlinks=False)
    else:
        ckpt_path = args.ckpt

    app = Flask(__name__)

    # Initialize the GenerationSystem
    device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
    generation_system = GenerationSystem(ckpt_path=ckpt_path, device=device)

    # Initialize the concurrency manager
    concurrency_manager = ConcurrencyManager(max_concurrent=args.max_concurrent)

    @app.after_request
    def after_request(response):
        response.headers.add('Access-Control-Allow-Origin', '*')
        response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
        response.headers.add('Access-Control-Allow-Methods', 'GET,PUT,POST,DELETE,OPTIONS')
        return response

    @GPU
    def generate_wrapper(cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path=None):
        """Wrapper around the generation function, used for concurrency control."""
        return generation_system.generate(cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path)

    def job_generate(file_id, cache_dir, payload):
        """Generation job run by a worker thread: generates, writes results to disk, and returns download info."""
        # Unpack the parameters
        cameras = payload["cameras"]
        n_frame = payload["n_frame"]
        image = payload["image"]
        text_prompt = payload["text_prompt"]
        image_index = payload["image_index"]
        image_height = payload["image_height"]
        image_width = payload["image_width"]
        data = payload["raw_request"]

        # Run generation
        scene_params, ref_w2c, T_norm = generation_system.generate(
            cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path=None
        )

        # Save the request metadata
        with open(os.path.join(cache_dir, f'{file_id}.json'), 'w') as f:
            json.dump(data, f)

        # Export the PLY file
        splat_path = os.path.join(cache_dir, f'{file_id}.ply')
        export_ply_for_gaussians(splat_path, scene_params, opacity_threshold=0.001, T_norm=T_norm)

        file_size = os.path.getsize(splat_path) if os.path.exists(splat_path) else 0

        return {
            'file_id': file_id,
            'file_path': splat_path,
            'file_size': file_size,
            'download_url': f'/download/{file_id}'
        }

    @app.route('/generate', methods=['POST', 'OPTIONS'])
    def generate():
        # Handle preflight request
        if request.method == 'OPTIONS':
            return jsonify({'status': 'ok'})

        try:
            data = request.get_json(force=True)

            image_prompt = data.get('image_prompt', None)
            text_prompt = data.get('text_prompt', "")
            cameras = data.get('cameras')
            resolution = data.get('resolution')
            image_index = data.get('image_index', 0)

            n_frame, image_height, image_width = resolution

            if not image_prompt and text_prompt == "":
                return jsonify({'error': 'No Prompts provided'}), 400

            # Process the image
            if image_prompt:
                # image_prompt can be a path or a base64 string
                if os.path.exists(image_prompt):
                    image_prompt = Image.open(image_prompt)
                else:
                    # image_prompt may look like "data:image/png;base64,...."
                    if ',' in image_prompt:
                        image_prompt = image_prompt.split(',', 1)[1]

                    try:
                        image_bytes = base64.b64decode(image_prompt)
                        image_prompt = Image.open(io.BytesIO(image_bytes))
                    except Exception as img_e:
                        return jsonify({'error': f'Image decode error: {str(img_e)}'}), 400

                image = image_prompt.convert('RGB')

                w, h = image.size

                # center crop
                if image_height / h > image_width / w:
                    scale = image_height / h
                else:
                    scale = image_width / w

                new_h = int(image_height / scale)
                new_w = int(image_width / scale)

                image = image.crop(((w - new_w) // 2, (h - new_h) // 2,
                                    new_w + (w - new_w) // 2, new_h + (h - new_h) // 2)).resize((image_width, image_height))

                for camera in cameras:
                    camera['fx'] = camera['fx'] * scale
                    camera['fy'] = camera['fy'] * scale
                    camera['cx'] = (camera['cx'] - (w - new_w) // 2) * scale
                    camera['cy'] = (camera['cy'] - (h - new_h) // 2) * scale

                image = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0 * 2 - 1
            else:
                image = None

            cameras = torch.stack([
                torch.from_numpy(np.array([camera['quaternion'][0], camera['quaternion'][1], camera['quaternion'][2], camera['quaternion'][3], camera['position'][0], camera['position'][1], camera['position'][2], camera['fx'] / image_width, camera['fy'] / image_height, camera['cx'] / image_width, camera['cy'] / image_height], dtype=np.float32))
                for camera in cameras
            ], dim=0)

            file_id = str(int(time.time() * 1000))

            # Assemble the task payload; execution and disk writes are deferred to a worker thread
            payload = {
                'cameras': cameras,
                'n_frame': n_frame,
                'image': image,
                'text_prompt': text_prompt,
                'image_index': image_index,
                'image_height': image_height,
                'image_width': image_width,
                'raw_request': data,
            }

            # Submit the task to the concurrency manager (asynchronous)
            task_id = concurrency_manager.submit_task(
                job_generate, file_id, args.cache_dir, payload
            )

            # Return the queue info immediately after submission
            queue_status = concurrency_manager.get_queue_status()
            queued_tasks = queue_status.get('queued_tasks', [])
            try:
                queue_position = queued_tasks.index(task_id) + 1
            except ValueError:
                # If the task was picked up by a worker right away, treat it as started (position 0)
                queue_position = 0

            return jsonify({
                'success': True,
                'task_id': task_id,
                'file_id': file_id,
                'queue': {
                    'queued_count': queue_status.get('queued_count', 0),
                    'running_count': queue_status.get('running_count', 0),
                    'position': queue_position
                }
            }), 202

        except Exception as e:
            return jsonify({'error': f'Server error: {str(e)}'}), 500

    @app.route('/download/<file_id>', methods=['GET'])
    def download_file(file_id):
        """Download the generated PLY file."""
        file_path = os.path.join(args.cache_dir, f'{file_id}.ply')

        if not os.path.exists(file_path):
            return jsonify({'error': 'File not found'}), 404

        return send_file(file_path, as_attachment=True, download_name=f'{file_id}.ply')

    @app.route('/delete/<file_id>', methods=['DELETE', 'POST', 'OPTIONS'])
    def delete_file_endpoint(file_id):
        """Delete a generated file and its metadata (called by the front end once the download finishes)."""
        # CORS preflight
        if request.method == 'OPTIONS':
            return jsonify({'status': 'ok'})

        try:
            ply_path = os.path.join(args.cache_dir, f'{file_id}.ply')
            json_path = os.path.join(args.cache_dir, f'{file_id}.json')
            deleted = []
            for path in [ply_path, json_path]:
                if os.path.exists(path):
                    os.remove(path)
                    deleted.append(os.path.basename(path))
            return jsonify({'success': True, 'deleted': deleted})
        except Exception as e:
            return jsonify({'success': False, 'error': str(e)}), 500

    @app.route('/status', methods=['GET'])
    def get_status():
        """Get the system status and queue info."""
        try:
            queue_status = concurrency_manager.get_queue_status()
            return jsonify({
                'success': True,
                'status': queue_status,
                'timestamp': time.time()
            })
        except Exception as e:
            return jsonify({'error': f'Failed to get status: {str(e)}'}), 500

    @app.route('/task/<task_id>', methods=['GET'])
    def get_task_status(task_id):
        """Get the status of a specific task (queue position and, once completed, file info)."""
        try:
            task = concurrency_manager.get_task_status(task_id)
            if not task:
                return jsonify({'error': 'Task not found'}), 404

            queue_status = concurrency_manager.get_queue_status()
            queued_tasks = queue_status.get('queued_tasks', [])
            try:
                queue_position = queued_tasks.index(task_id) + 1
            except ValueError:
                queue_position = 0

            resp = {
                'success': True,
                'task_id': task_id,
                'status': task.status.value,
                'created_at': task.created_at,
                'started_at': task.started_at,
                'completed_at': task.completed_at,
                'error': task.error,
                'queue': {
                    'queued_count': queue_status.get('queued_count', 0),
                    'running_count': queue_status.get('running_count', 0),
                    'position': queue_position
                }
            }

            if task.status.value == 'completed' and isinstance(task.result, dict):
                resp.update({
                    'file_id': task.result.get('file_id'),
                    'file_path': task.result.get('file_path'),
                    'file_size': task.result.get('file_size'),
                    'download_url': task.result.get('download_url'),
                    'generation_time': (task.completed_at - task.started_at)
                })

            # Update the task status

            return jsonify(resp)
        except Exception as e:
            return jsonify({'error': f'Failed to get task status: {str(e)}'}), 500

    @app.route("/")
    def index():
        return send_file("index.html")

    os.makedirs(args.cache_dir, exist_ok=True)

    # Background periodic cleanup: delete cached files not modified for over 30 minutes
    def cleanup_worker(cache_dir: str, max_age_seconds: int = 1800, interval_seconds: int = 300):
        while True:
            try:
                now = time.time()
                for name in os.listdir(cache_dir):
                    # Only clean up the task-related .ply/.json files
                    if not (name.endswith('.ply') or name.endswith('.json')):
                        continue
                    path = os.path.join(cache_dir, name)
                    try:
                        mtime = os.path.getmtime(path)
                        if now - mtime > max_age_seconds:
                            os.remove(path)
                    except FileNotFoundError:
                        pass
                    except Exception:
                        # Ignore per-file errors and keep cleaning
                        pass
            except Exception:
                # Keep the thread alive on unexpected errors
                pass
            time.sleep(interval_seconds)

    cleaner_thread = threading.Thread(target=cleanup_worker, args=(args.cache_dir,), daemon=True)
    cleaner_thread.start()

    app.run(host='0.0.0.0', port=args.port)
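Taken together, the routes above form a small asynchronous job API: `POST /generate` validates the request, packs each camera into an 11-float vector (quaternion, position, and normalized intrinsics), submits the job, and immediately returns `202` with a `task_id`; clients then poll `GET /task/<task_id>` and fetch the exported Gaussian `.ply` from `GET /download/<file_id>`. A minimal client sketch under those assumptions (not part of the commit; `requests` and all camera values below are placeholders, not values from the repo):

```python
# Minimal client sketch for the HTTP API defined in app.py (not part of this commit).
# Camera values are illustrative placeholders; a real trajectory comes from the web UI.
import time

import requests  # assumed dependency

BASE = "http://localhost:7860"  # assumed host; matches app.py's --port default

camera = {
    "quaternion": [1.0, 0.0, 0.0, 0.0],  # placeholder unit quaternion, packed as-is by /generate
    "position": [0.0, 0.0, 0.0],
    "fx": 704.0, "fy": 704.0, "cx": 352.0, "cy": 240.0,  # pixel intrinsics; /generate normalizes them
}
payload = {
    "text_prompt": "a cozy living room",
    "cameras": [camera] * 24,      # generate() resamples if the count differs from n_frame
    "resolution": [24, 480, 704],  # (n_frame, image_height, image_width)
    "image_index": 0,
}

# /generate returns 202 right away; the job runs on a ConcurrencyManager worker thread.
task = requests.post(f"{BASE}/generate", json=payload).json()
task_id = task["task_id"]

# Poll /task/<task_id> until the job leaves the queue and finishes.
while True:
    status = requests.get(f"{BASE}/task/{task_id}").json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(1.0)

# On success the response carries the download URL for the exported .ply.
if status["status"] == "completed":
    ply = requests.get(BASE + status["download_url"])
    with open("scene.ply", "wb") as f:
        f.write(ply.content)
```

Returning `202` immediately keeps the Flask handler free; the actual generation runs on the worker threads of the `ConcurrencyManager` defined in the next file.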
concurrency_manager.py ADDED
@@ -0,0 +1,203 @@
import threading
import time
import uuid
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass
from enum import Enum

class TaskStatus(Enum):
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"

@dataclass
class Task:
    task_id: str
    status: TaskStatus
    created_at: float
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    result: Optional[Any] = None
    error: Optional[str] = None
    function: Optional[Callable] = None
    args: tuple = ()
    kwargs: dict = None

    def __post_init__(self):
        if self.kwargs is None:
            self.kwargs = {}

class ConcurrencyManager:
    def __init__(self, max_concurrent: int = 2):
        """
        Concurrency control manager.

        Args:
            max_concurrent: maximum number of concurrent tasks
        """
        self.max_concurrent = max_concurrent
        self.running_tasks: Dict[str, Task] = {}
        self.queued_tasks: List[Task] = []
        self.completed_tasks: Dict[str, Task] = {}
        self.lock = threading.RLock()
        self.worker_threads: List[threading.Thread] = []
        self.shutdown_event = threading.Event()

        # Start the worker threads
        self._start_workers()

    def _start_workers(self):
        """Start the worker threads."""
        for i in range(self.max_concurrent):
            worker = threading.Thread(target=self._worker_loop, daemon=True)
            worker.start()
            self.worker_threads.append(worker)

    def _worker_loop(self):
        """Main loop of a worker thread."""
        while not self.shutdown_event.is_set():
            try:
                task = self._get_next_task()
                if task:
                    self._execute_task(task)
                else:
                    # Sleep briefly when there is no task
                    time.sleep(0.1)
            except Exception as e:
                print(f"Worker thread error: {e}")
                time.sleep(1)

    def _get_next_task(self) -> Optional[Task]:
        """Get the next task to execute."""
        with self.lock:
            if self.queued_tasks:
                return self.queued_tasks.pop(0)
            return None

    def _execute_task(self, task: Task):
        """Execute a task."""
        try:
            with self.lock:
                task.status = TaskStatus.RUNNING
                task.started_at = time.time()
                self.running_tasks[task.task_id] = task

            # Run the task
            if task.function:
                result = task.function(*task.args, **task.kwargs)
                task.result = result

            # Mark as completed
            with self.lock:
                task.status = TaskStatus.COMPLETED
                task.completed_at = time.time()
                self.completed_tasks[task.task_id] = task
                if task.task_id in self.running_tasks:
                    del self.running_tasks[task.task_id]

        except Exception as e:
            # Mark as failed
            with self.lock:
                task.status = TaskStatus.FAILED
                task.completed_at = time.time()
                task.error = str(e)
                self.completed_tasks[task.task_id] = task
                if task.task_id in self.running_tasks:
                    del self.running_tasks[task.task_id]

    def submit_task(self, func: Callable, *args, **kwargs) -> str:
        """
        Submit a task.

        Args:
            func: the function to execute
            *args: positional arguments for the function
            **kwargs: keyword arguments for the function

        Returns:
            task_id: the task ID
        """
        task_id = str(uuid.uuid4())
        task = Task(
            task_id=task_id,
            status=TaskStatus.QUEUED,
            created_at=time.time(),
            function=func,
            args=args,
            kwargs=kwargs
        )

        with self.lock:
            self.queued_tasks.append(task)

        return task_id

    def get_task_status(self, task_id: str) -> Optional[Task]:
        """Get the status of a task."""
        with self.lock:
            if task_id in self.running_tasks:
                return self.running_tasks[task_id]
            elif task_id in self.completed_tasks:
                return self.completed_tasks[task_id]
            else:
                # Check the tasks still in the queue
                for task in self.queued_tasks:
                    if task.task_id == task_id:
                        return task
                return None

    def get_queue_status(self) -> Dict[str, Any]:
        """Get the queue status."""
        with self.lock:
            return {
                "max_concurrent": self.max_concurrent,
                "running_count": len(self.running_tasks),
                "queued_count": len(self.queued_tasks),
                "completed_count": len(self.completed_tasks),
                "running_tasks": [task.task_id for task in self.running_tasks.values()],
                "queued_tasks": [task.task_id for task in self.queued_tasks],
            }

    def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> Task:
        """
        Wait for a task to finish.

        Args:
            task_id: the task ID
            timeout: timeout in seconds; None means wait indefinitely

        Returns:
            Task: the finished task
        """
        start_time = time.time()

        while True:
            task = self.get_task_status(task_id)
            if task and task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
                return task

            if timeout and (time.time() - start_time) > timeout:
                raise TimeoutError(f"Task {task_id} timed out after {timeout} seconds")

            time.sleep(0.1)

    def cleanup_old_tasks(self, max_age_hours: int = 24):
        """Clean up old tasks."""
        current_time = time.time()
        max_age_seconds = max_age_hours * 3600

        with self.lock:
            # Remove completed tasks that are too old
            old_tasks = [
                task_id for task_id, task in self.completed_tasks.items()
                if current_time - task.completed_at > max_age_seconds
            ]
            for task_id in old_tasks:
                del self.completed_tasks[task_id]

    def shutdown(self):
        """Shut down the manager."""
        self.shutdown_event.set()
        for worker in self.worker_threads:
            worker.join(timeout=5)
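ConcurrencyManager is a hand-rolled thread pool: `max_concurrent` daemon threads poll a shared queue under an RLock, and finished `Task` objects stay in `completed_tasks` until `cleanup_old_tasks` removes them. A minimal standalone sketch of the intended usage (not from the commit; `slow_square` is a made-up stand-in for the real `job_generate` in app.py):

```python
# Standalone usage sketch for ConcurrencyManager (not part of this commit).
import time

from concurrency_manager import ConcurrencyManager, TaskStatus

def slow_square(x):
    time.sleep(0.5)  # pretend to do expensive work
    return x * x

manager = ConcurrencyManager(max_concurrent=2)

# submit_task enqueues the call and returns a UUID task id immediately.
ids = [manager.submit_task(slow_square, n) for n in range(4)]

# wait_for_task blocks (polling every 0.1 s) until the task completes or fails.
for task_id in ids:
    task = manager.wait_for_task(task_id, timeout=10)
    if task.status is TaskStatus.COMPLETED:
        print(task_id, "->", task.result)
    else:
        print(task_id, "failed:", task.error)

manager.shutdown()
```

app.py uses only the non-blocking half of this API (`submit_task` plus `get_task_status`/`get_queue_status`); `wait_for_task` is the blocking alternative shown here.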
index.html ADDED
@@ -0,0 +1,2130 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>FlashWorld Demo</title>
  <meta name="description" content="">
  <style>
    body {
      margin: 0;
      font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
      background: #1a1a1a;
      color: #ffffff;
      overflow: hidden;
    }

    .main-container {
      display: flex;
      height: 100vh;
      flex-direction: column;
    }

    .header {
      background: rgba(0, 0, 0, 0.8);
      padding: 15px 20px;
      text-align: center;
      border-bottom: 1px solid rgba(255, 255, 255, 0.1);
      flex-shrink: 0;
    }

    .header h1 {
      margin: 0;
      color: white;
      font-size: 1.8em;
      font-weight: 600;
      margin-bottom: 8px;
    }
    .header-title-wrap {
      display: inline-flex;
      align-items: center;
      gap: 8px;
      position: relative;
    }

    .header-links {
      display: flex;
      justify-content: center;
      gap: 20px;
      margin-top: 8px;
    }

    .header-links a {
      color: #60a5fa;
      text-decoration: none;
      font-size: 0.9em;
      padding: 5px 10px;
      border: 1px solid #60a5fa;
      border-radius: 5px;
      transition: all 0.3s ease;
    }

    .header-links a:hover {
      background: #60a5fa;
      color: white;
    }

    .content-container {
      display: flex;
      flex: 1;
      overflow: hidden;
    }

    .left-panel {
      width: 280px;
      background: rgba(0, 0, 0, 0.7);
      border-right: 1px solid rgba(255, 255, 255, 0.1);
      padding: 20px;
      overflow-y: auto;
      flex-shrink: 0;
    }

    .center-panel {
      flex: 1;
      position: relative;
      background: #000;
      display: flex;
      justify-content: center;
      align-items: center;
    }

    .right-panel {
      width: 300px;
      background: rgba(0, 0, 0, 0.7);
      border-left: 1px solid rgba(255, 255, 255, 0.1);
      padding: 20px;
      overflow-y: auto;
      flex-shrink: 0;
    }

    .guidance {
      color: #e5e7eb;
    }

    .guidance h2 {
      color: #ffffff;
      margin-top: 0;
      font-size: 1.3em;
      border-bottom: 2px solid #60a5fa;
      padding-bottom: 8px;
      margin-bottom: 20px;
    }

    .gui-container h2 {
      color: #ffffff;
      margin-top: 0;
      font-size: 1.3em;
      border-bottom: 2px solid #60fae5;
      padding-bottom: 8px;
      margin-bottom: 20px;
    }

    .step {
      margin: 12px 0;
      padding: 12px;
      background: rgba(96, 165, 250, 0.1);
      border-radius: 6px;
      border-left: 3px solid #60a5fa;
    }

    .step h3 {
      margin: 0 0 8px 0;
      color: #ffffff;
      font-size: 1em;
    }

    .step p {
      margin: 4px 0;
      line-height: 1.4;
      font-size: 0.85em;
      color: #d1d5db;
    }

    .controls-info {
      background: rgba(168, 85, 247, 0.1);
      border-left: 3px solid #a855f7;
    }

    .keyboard-shortcuts {
      background: rgba(34, 197, 94, 0.1);
      border-left: 3px solid #22c55e;
    }

    .loading {
      position: absolute;
      top: 50%;
      left: 50%;
      min-width: 300px;
      min-height: 200px;
      transform: translate(-50%, -50%);
      background: rgba(0, 0, 0, 0.9);
      color: white;
      padding: 20px;
      border-radius: 10px;
      display: none;
      z-index: 1000;
      text-align: center;
      vertical-align: middle;
    }

    .generation-info {
      background: rgba(34, 197, 94, 0.1);
      border: 1px solid #22c55e;
      border-radius: 8px;
      padding: 15px;
      margin: 10px 0;
      color: #22c55e;
      font-family: 'Courier New', monospace;
      font-size: 0.9em;
    }

    .progress-container {
      width: 100%;
      background: rgba(255, 255, 255, 0.1);
      border-radius: 10px;
      overflow: hidden;
      margin: 10px 0;
      position: relative;
    }

    .progress-bar {
      height: 20px;
      background: linear-gradient(90deg, #60a5fa, #3b82f6);
      width: 0%;
      transition: width 0.3s ease;
      border-radius: 10px;
      position: relative;
    }

    .progress-text {
      position: absolute;
      top: 50%;
      left: 50%;
      transform: translate(-50%, -50%);
      color: white;
      font-weight: bold;
      font-size: 0.8em;
      white-space: nowrap;
    }

    /* Info tooltip */
    .info-tip {
      display: inline-block;
      position: relative;
      margin-left: 8px;
      width: 16px;
      height: 16px;
      line-height: 16px;
      text-align: center;
      border-radius: 50%;
      background: #3b82f6;
      color: #fff;
      font-size: 12px;
      cursor: default;
      user-select: none;
    }
    .info-tip .tooltip {
      display: none;
      position: absolute;
      left: 0;
      top: calc(100% + 8px); /* show below the icon */
      transform: none;
      background: rgba(0,0,0,0.9);
      color: #e5e7eb;
      border: 1px solid rgba(255,255,255,0.15);
      border-radius: 8px;
      padding: 10px 12px;
      font-size: 12px;
      width: 360px; /* wider tooltip */
      white-space: normal;
      z-index: 2000; /* above GUI and other elements */
      box-shadow: 0 4px 12px rgba(0,0,0,0.4);
    }
    .info-tip:hover .tooltip {
      display: block;
    }

    .status-bar {
      background: rgba(0, 0, 0, 0.9);
      color: #60a5fa;
      padding: 8px 15px;
      font-family: 'Courier New', monospace;
      font-size: 0.8em;
      border-top: 1px solid rgba(255, 255, 255, 0.1);
      flex-shrink: 0;
    }

    .canvas-container {
      width: 100%;
      height: 100%;
      display: flex;
      justify-content: center;
      align-items: center;
      background:
        repeating-linear-gradient(
          45deg,
          #1a1a1a 0px,
          #1a1a1a 10px,
          #2a2a2a 10px,
          #2a2a2a 20px
        );
      position: relative;
    }

    .canvas-wrapper {
      position: relative;
      border: 2px solid #444;
      background: #111;
      box-shadow:
        0 0 20px rgba(0, 0, 0, 0.5),
        inset 0 0 10px rgba(0, 0, 0, 0.3);
      border-radius: 4px;
    }

    .canvas-wrapper canvas {
      display: block;
      border-radius: 2px;
    }

    /* Add a subtle animation to the canvas wrapper */
    .canvas-wrapper:hover {
      border-color: #666;
      box-shadow:
        0 0 30px rgba(0, 0, 0, 0.7),
        inset 0 0 15px rgba(0, 0, 0, 0.4);
    }

    /* Progress & status beautify (overrides the base progress styles above) */
    .progress-container {
      width: 100%;
      height: 18px;
      background: linear-gradient(180deg, rgba(255,255,255,0.06), rgba(255,255,255,0.02));
      border: 1px solid rgba(255,255,255,0.12);
      border-radius: 999px;
      overflow: hidden;
      box-shadow: 0 2px 10px rgba(0,0,0,0.35) inset;
      position: relative;
    }
    .progress-bar {
      height: 100%;
      background: linear-gradient(90deg, #60a5fa, #8b5cf6);
      box-shadow: 0 0 10px rgba(96,165,250,0.65);
      position: relative;
      transition: width .15s ease;
    }
    .progress-text {
      position: absolute;
      top: 50%;
      left: 50%;
      transform: translate(-50%, -50%);
      font-size: 11px;
      color: #f8fafc;
      text-shadow: 0 1px 2px rgba(0,0,0,0.5);
      pointer-events: none;
      white-space: nowrap;
    }

    .status-badges {
      display: flex;
      gap: 8px;
      flex-wrap: wrap;
      margin-top: 8px;
    }
    .badge {
      display: inline-flex;
      align-items: center;
      gap: 6px;
      padding: 6px 10px;
      border-radius: 8px;
      font-size: 12px;
      border: 1px solid rgba(255,255,255,0.12);
      background: rgba(255,255,255,0.06);
    }
    .badge .dot { width: 8px; height: 8px; border-radius: 999px; }
    .badge.queue .dot { background: #f59e0b; }
    .badge.running .dot { background: #22c55e; }
    .badge.time .dot { background: #60a5fa; }
    .badge.bytes .dot { background: #a78bfa; }

    .details-grid {
      display: grid;
      grid-template-columns: repeat(2, minmax(0, 1fr));
      gap: 6px 12px;
      margin-top: 8px;
      font-size: 12px;
      color: #cbd5e1;
    }
    .details-grid div { opacity: 0.9; }

    /* Canvas resizing indicator */
    .canvas-wrapper.resizing {
      border-color: #60a5fa;
      box-shadow:
        0 0 25px rgba(96, 165, 250, 0.3),
        inset 0 0 10px rgba(96, 165, 250, 0.1);
    }

    .canvas-wrapper.resizing::after {
      content: "Resizing...";
      position: absolute;
      top: 50%;
      left: 50%;
      transform: translate(-50%, -50%);
      color: #60a5fa;
      font-size: 12px;
      font-weight: bold;
      z-index: 10;
      pointer-events: none;
    }

    /* GUI Panel Styling */
    .gui-panel {
      background: rgba(0, 0, 0, 0.8);
      border-radius: 8px;
      padding: 15px;
      min-height: 400px;
    }

    .gui-panel .lil-gui {
      --background-color: rgba(0, 0, 0, 0.8);
      --text-color: #ffffff;
      --title-background-color: rgba(96, 165, 250, 0.2);
      --title-text-color: #ffffff;
      --widget-color: rgba(96, 165, 250, 0.3);
      --hover-color: rgba(96, 165, 250, 0.5);
    }

    /* Ensure GUI is visible */
    .lil-gui {
      position: relative !important;
      z-index: 1000 !important;
    }

    @media (max-width: 1200px) {
      .left-panel {
        width: 250px;
      }

      .right-panel {
        width: 280px;
      }
    }

    @media (max-width: 768px) {
      .content-container {
        flex-direction: column;
      }

      .left-panel, .right-panel {
        width: 100%;
        height: auto;
        max-height: 200px;
      }

      .center-panel {
        flex: 1;
        min-height: 400px;
      }
    }
  </style>
  <script type="importmap">
    {
      "imports": {
        "three": "https://cdnjs.cloudflare.com/ajax/libs/three.js/0.174.0/three.module.js",
        "@sparkjsdev/spark": "https://sparkjs.dev/releases/spark/0.1.6/spark.module.js",
        "lil-gui": "https://cdn.jsdelivr.net/npm/lil-gui@0.20/+esm"
      }
    }
  </script>
</head>
<body>
  <div class="main-container">
    <!-- Header Section -->
    <header class="header">
      <div style="display: flex; justify-content: space-between; align-items: center; width: 100%;">
        <h1 style="margin: 0; flex: 1; text-align: left;">
          <span class="header-title-wrap">FlashWorld Spark Demo
            <span class="info-tip">!
              <span class="tooltip" style="max-width: 260px; text-align: left;">Note: Front-end real-time rendering in Spark uses compressed Gaussian Splat attributes. Visual quality in this demo may be lower than offline/back-end rendering.
                Also, generation is fast, but downloading may be slow; please be patient.
              </span>
            </span>
          </span>
        </h1>
        <div class="header-links" style="margin-left: 20px;">
          <a href="#" target="_blank">Paper</a>
          <a href="#" target="_blank">Code</a>
          <a href="#" target="_blank">Project Page</a>
        </div>
      </div>
    </header>

    <!-- Main Content Container -->
    <div class="content-container">
      <!-- Left Panel: Simplified Guidance -->
      <div class="left-panel">
        <div class="guidance">
          <h2>Instructions</h2>

          <div class="step">
            <h3>1. Configure</h3>
            <p>Set FOV and Resolution, then click "Fix Configurations"</p>
          </div>

          <div class="step">
            <h3>2. Set Camera Trajectory</h3>
            <p><b>Manual:</b> Navigate with mouse and keyboard, press <kbd>Space</kbd> to record</p>
            <p><b>Template:</b> Select a template type and click "Generate Trajectory"</p>
            <p><b>JSON:</b> Load a trajectory from a JSON file</p>
          </div>

          <div class="step">
            <h3>3. Add Prompts</h3>
            <p>Upload an image or enter a text description</p>
          </div>

          <div class="step">
            <h3>4. Generate</h3>
            <p>Click "Generate!" to create your scene</p>
          </div>

          <div class="step controls-info">
            <h3>Controls</h3>
            <p><strong>Mouse/QE:</strong> Rotate view</p>
            <p><strong>WASD/RF:</strong> Move</p>
            <p><strong>Space:</strong> Record camera</p>
          </div>

        </div>
      </div>

      <!-- Center Panel: Canvas -->
      <div class="center-panel">
        <div class="canvas-container" id="canvas-container">
          <div class="canvas-wrapper" id="canvas-wrapper">
            <div class="loading" id="loading">
              <h3>🎬 Generating Scene...</h3>
              <p>Please wait while we create your 3D scene</p>
              <div id="generation-info" class="generation-info" style="display: none;">
                <div><strong>Generation Time:</strong> <span id="generation-time">-</span> seconds</div>
                <div><strong>File Size:</strong> <span id="file-size">-</span> MB</div>
              </div>
              <div id="download-progress" style="display: none;">
                <div class="progress-container">
                  <div class="progress-bar" id="progress-bar"></div>
                  <div class="progress-text" id="progress-text">0%</div>
                </div>
                <div class="status-badges" id="status-badges" style="display: none;">
                  <div class="badge queue" id="badge-queue"><span class="dot"></span><span id="badge-queue-text">Queue</span></div>
                  <div class="badge running" id="badge-running" style="display: none;"><span class="dot"></span><span id="badge-running-text">Running</span></div>
                  <div class="badge time" id="badge-time" style="display: none;"><span class="dot"></span><span id="badge-time-text">00:00</span></div>
                </div>
                <div id="queue-details" class="details-grid" style="display: none;"></div>
                <div id="download-details" class="details-grid" style="display: none;"></div>
              </div>
            </div>
          </div>
        </div>
      </div>

      <!-- Right Panel: GUI -->
      <div class="right-panel">
        <div class="gui-container">
          <!-- <h2>GUI</h2> -->
          <div class="gui-panel" id="gui-container">
            <!-- GUI will be inserted here -->
          </div>
        </div>

        <!-- Image Preview Area -->
        <div id="image-preview-area" style="padding: 10px; display: none;">
          <div style="font-size: 12px; color: #ccc; margin-bottom: 8px; text-align: left;">Input Image Preview</div>
          <div style="text-align: center;">
            <img id="preview-img" style="max-width: 100%; max-height: 200px; border-radius: 4px; box-shadow: 0 2px 8px rgba(0,0,0,0.3);" />
          </div>
        </div>
      </div>
    </div>

    <!-- Status Bar -->
    <div class="status-bar" id="status-bar">
      Ready to generate 3D scenes | Cameras: 0 | Status: Waiting for input
    </div>
  </div>

  <!-- Hidden File Inputs -->
  <input id="file-input" type="file" accept=".jpg,.png,.jpeg" multiple style="display: none;" />
  <input id="json-input" type="file" accept=".json" style="display: none;" />

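  <!-- Note: these inputs stay hidden; the lil-gui buttons defined in the script below
       (inputImagePrompt / LoadFromJson / LoadTrajectoryFromJson) open them via .click(). -->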
  <script type="module">
    // =========================
    // Imports & Global Variables
    // =========================
    import * as THREE from "three";
    import { SplatMesh, SparkControls, textSplats } from "@sparkjsdev/spark";
    import GUI from "lil-gui";

    // Scene, Camera, Renderer, Controls
    const scene = new THREE.Scene();
    const camera = new THREE.PerspectiveCamera(60, window.innerWidth / window.innerHeight, 0.1, 1000);
    camera.position.set(0, 0, 1.5);
    const renderer = new THREE.WebGLRenderer();
    renderer.setSize(window.innerWidth, window.innerHeight);

    // Wait for DOM to be ready
    function initializeRenderer() {
      const canvasWrapper = document.getElementById('canvas-wrapper');
      if (canvasWrapper) {
        canvasWrapper.appendChild(renderer.domElement);

        // Set initial canvas size based on current resolution
        updateCanvasSize();
        console.log('Canvas initialized in wrapper');
      } else {
        console.error('Canvas wrapper not found');
      }
    }

    // Update canvas size based on selected resolution
    function updateCanvasSize() {
      const canvasWrapper = document.getElementById('canvas-wrapper');
      if (!canvasWrapper) return;

      // Show resizing indicator
      canvasWrapper.classList.add('resizing');

      // Get current resolution from GUI options ("frames x H x W")
      const resolution = guiOptions.Resolution.split('x');
      const width = parseInt(resolution[2]) || 704;  // W
      const height = parseInt(resolution[1]) || 480; // H

      // Set canvas size
      renderer.setSize(width, height);
      camera.aspect = width / height;
      camera.updateProjectionMatrix();

      // Update wrapper size to match canvas
      canvasWrapper.style.width = width + 'px';
      canvasWrapper.style.height = height + 'px';

      // Remove resizing indicator after a short delay
      setTimeout(() => {
        canvasWrapper.classList.remove('resizing');
      }, 300);

      console.log('Canvas size updated:', width, 'x', height);
    }

    const controls = new SparkControls({ canvas: renderer.domElement });

    // Camera splats and params
    const cameraSplats = [];
    const cameraParams = [];
    const interpolatedCamerasSplats = [];

    // State
    let fixGenerationFOV = false;
    let inputImageBase64 = null;
    let inputImageResolution = null;
    let currentGeneratedSplat = null; // Tracks the currently generated scene

    // UI Elements
    const loadingElement = document.getElementById('loading');
    const statusBar = document.getElementById('status-bar');

    // GUI variable - declare early
    let gui = null;

    // Status update function
    function updateStatus(message, cameraCount = null) {
      const cameraText = cameraCount !== null ? `Cameras: ${cameraCount}` : `Cameras: ${cameraParams.length}`;
      statusBar.textContent = `${message} | ${cameraText} | Status: ${fixGenerationFOV ? 'Ready to record' : 'Configure settings'}`;
    }

    // Show/hide loading
    function showLoading(show) {
      loadingElement.style.display = show ? 'block' : 'none';
    }

    // Show generation info
    function showGenerationInfo(generationTime, fileSize) {
      const generationInfo = document.getElementById('generation-info');
      const generationTimeElement = document.getElementById('generation-time');
      const fileSizeElement = document.getElementById('file-size');

      generationTimeElement.textContent = generationTime.toFixed(2);
      fileSizeElement.textContent = (fileSize / (1024 * 1024)).toFixed(2);
      generationInfo.style.display = 'block';
    }

    // Show download progress
    function showDownloadProgress() {
      const downloadProgress = document.getElementById('download-progress');
      downloadProgress.style.display = 'block';
      const qd = document.getElementById('queue-details');
      const dd = document.getElementById('download-details');
      const badges = document.getElementById('status-badges');
      if (qd) qd.style.display = 'none';
      if (dd) dd.style.display = 'none';
      if (badges) badges.style.display = 'none';
    }

    // Update progress bar
    function updateProgressBar(percentage) {
      const progressBar = document.getElementById('progress-bar');
      const progressText = document.getElementById('progress-text');

      progressBar.style.width = percentage + '%';
      progressText.textContent = `${Math.round(percentage)}%`;
    }

    // Update progress label text (stage indicator)
    function setProgressLabel(text) {
      const progressText = document.getElementById('progress-text');
      if (progressText) progressText.textContent = text;
    }

    // ==============
    // Queue handling
    // ==============
    let queuePollTimer = null;
    let currentTaskId = null;
    let initialQueuePosition = null;
    let latestGenerationTime = null;
    let lastDownloadPct = 0;
    let lastDownloadUpdateTs = 0;

    function showQueueWaiting(position, runningCount, queuedCount) {
      // Use only the progress bar to show queue progress (from initial position to 0)
      showDownloadProgress();
      if (initialQueuePosition === null) {
        // Initialize from first seen position; ensure >= 1 so 0 -> 100%
        const initPos = (typeof position === 'number') ? position : 0;
        initialQueuePosition = Math.max(initPos, 1);
      }
      const percent = initialQueuePosition && initialQueuePosition > 0
        ? Math.max(0, Math.min(100, ((initialQueuePosition - (position || 0)) / initialQueuePosition) * 100))
        : 0;
      updateProgressBar(percent);
      const totalWaiting = (position || 0) + (queuedCount || 0);
      if (position !== null && position !== undefined) {
        const pctText = `${Math.round(percent)}%`;
        if (totalWaiting > 0) {
          setProgressLabel(`Queued ${position}/${totalWaiting} (${pctText})`);
        } else {
          setProgressLabel(`Queued ${position} (${pctText})`);
        }
      } else {
        setProgressLabel('Queued');
      }
    }

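    // Worked example: if a task is first seen at queue position 4, initialQueuePosition = 4;
    // when the position later drops to 1, the bar shows (4 - 1) / 4 * 100 = 75%, and
    // reaching position 0 fills it to 100%.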
    async function pollTaskUntilReady(taskId) {
      currentTaskId = taskId;
      initialQueuePosition = null;
      if (queuePollTimer) {
        clearInterval(queuePollTimer);
        queuePollTimer = null;
      }
      const queueStartTs = Date.now();

      const pollOnce = async () => {
        try {
          const resp = await fetch(`${guiOptions.BackendAddress}/task/${taskId}`);
          if (!resp.ok) return;
          const info = await resp.json();
          if (!info || !info.success) return;

          const pos = info.queue && typeof info.queue.position === 'number' ? info.queue.position : 0;
          const running = info.queue ? info.queue.running_count : 0;
          const queued = info.queue ? info.queue.queued_count : 0;
          if (info.status === 'queued' || info.status === 'running') {
            // Only progress bar; set stage label
            if (info.status === 'queued') {
              showQueueWaiting(pos, running, queued);
            } else {
              // Transitioned to running: finalize queue progress visually
              updateProgressBar(100);
              showDownloadProgress();
              setProgressLabel('Generating...');
            }
          }

          if (info.status === 'completed' && info.download_url) {
            clearInterval(queuePollTimer);
            queuePollTimer = null;
            latestGenerationTime = typeof info.generation_time === 'number' ? info.generation_time : null;
            // Proceed to download the generated file like the normal path
            updateStatus('Downloading generated scene...', cameraParams.length);
            const response = await fetch(guiOptions.BackendAddress + info.download_url);
            if (!response.ok) throw new Error(`HTTP error! status: ${response.status}`);
            const contentLength = response.headers.get('content-length');
            const total = parseInt(contentLength || '0', 10);
            // Show generation info immediately once we know it and total size from headers
            showGenerationInfo(latestGenerationTime || 0, total);
            let loaded = 0;
            const reader = response.body.getReader();
            const chunks = [];
            updateProgressBar(0);
            setProgressLabel('Downloading 0%');
            lastDownloadPct = 0;
            lastDownloadUpdateTs = 0;
            while (true) {
              const { done, value } = await reader.read();
              if (done) break;
              chunks.push(value);
              loaded += value.length;
              if (total) {
                const pct = Math.min(100, (loaded / total) * 100);
                const now = Date.now();
                const rounded = Math.round(pct);
                // Throttle and enforce monotonic increase
                if (rounded > Math.round(lastDownloadPct) || (now - lastDownloadUpdateTs) > 200) {
                  lastDownloadPct = Math.max(lastDownloadPct, pct);
                  updateProgressBar(lastDownloadPct);
                  setProgressLabel(`Downloading ${Math.round(lastDownloadPct)}%`);
                  lastDownloadUpdateTs = now;
                }
              }
            }

            if (instructionSplat) {
              scene.remove(instructionSplat);
              console.log('Instruction splat removed');
              instructionSplat = null;
            }

            const blob = new Blob(chunks);
            const url = URL.createObjectURL(blob);
            // Continue to load the splat
            updateStatus('Loading generated scene...', cameraParams.length);

            const GeneratedSplat = new SplatMesh({ url });
            scene.add(GeneratedSplat);
            currentGeneratedSplat = GeneratedSplat;
            updateStatus('Scene generated successfully!', cameraParams.length);
            // Show generation time and total file size (MB)
            showGenerationInfo(latestGenerationTime || 0, total || blob.size);
            // Notify backend to delete the server file after client has downloaded it
            try {
              if (info.file_id) {
                const resp = await fetch(`${guiOptions.BackendAddress}/delete/${info.file_id}`, { method: 'POST' });
                if (!resp.ok) console.warn('Delete notify failed');
              }
            } catch (e) {
              console.warn('Delete notify error', e);
            }
            hideDownloadProgress();
            showLoading(false);
          } else if (info.status === 'failed') {
            clearInterval(queuePollTimer);
            queuePollTimer = null;
            // Surface the failure directly; a bare throw here would be swallowed
            // by the catch below, which only logs at debug level.
            updateStatus('Generation failed: ' + (info.error || 'unknown error'), cameraParams.length);
            hideDownloadProgress();
            showLoading(false);
          }
        } catch (e) {
          console.debug('Polling error:', e);
        }
      };

      await pollOnce();
      queuePollTimer = setInterval(pollOnce, 2000);
    }

    // Hide download progress
    function hideDownloadProgress() {
      const downloadProgress = document.getElementById('download-progress');
      downloadProgress.style.display = 'none';
    }

    // Playback scrubber (0..1)
    let userCameraState = null; // Stores the user's camera state before playback

    // Get the interpolated camera at a normalized time t
    function getInterpolatedCameraAtTime(t) {
      if (cameraParams.length === 0) {
        return camera;
      }

      if (cameraParams.length === 1) {
        return cameraParams[0];
      }

      // Clamp t to the valid range
      const clampedT = Math.max(0, Math.min(1, t));

      // Locate the position within the camera sequence
      const cameraIndex = clampedT * (cameraParams.length - 1);
      const startIndex = Math.min(Math.floor(cameraIndex), cameraParams.length - 2);
      const endIndex = startIndex + 1;
      const startCamera = cameraParams[startIndex];
      const endCamera = cameraParams[endIndex];

      // Fractional blend between the two cameras
      const _t = cameraIndex - startIndex;

      // Interpolate with interpolateTwoCameras
      return interpolateTwoCameras(startCamera, endCamera, _t);
    }

    function setCameraByScrub(t) {
      if (cameraParams.length === 0) return;
      const clampedT = Math.max(0, Math.min(1, t));
      const camT = getInterpolatedCameraAtTime(clampedT);
      camera.position.copy(camT.position);
      camera.quaternion.copy(camT.quaternion);
      camera.fov = camT.fov;
      camera.updateProjectionMatrix();
    }

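    // Worked example: with 4 recorded cameras, t = 0.5 gives cameraIndex = 0.5 * 3 = 1.5,
    // so the scrubbed camera blends cameras 1 and 2 with _t = 0.5.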
    // Supported resolutions
    const supportedResolutions = [
      { frame: 24, width: 704, height: 480 },
      { frame: 24, width: 480, height: 704 }
    ];

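    // The Resolution string below is "frames x height x width": e.g. "24x480x704"
    // means 24 frames at 480 (H) by 704 (W), and split('x') yields ["24", "480", "704"].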
    // GUI Options - declare early
    const guiOptions = {
      // Backend address; defaults to this page's host
      BackendAddress: `${window.location.protocol}//${window.location.hostname}:7860`,
      FOV: 60,
      LoadFromJson: () => {
        const jsonInput = document.querySelector("#json-input");
        if (jsonInput) jsonInput.click();
      },
      LoadTrajectoryFromJson: () => {
        if (!fixGenerationFOV) {
          updateStatus('Warning: Please fix configuration first before loading trajectory', cameraParams.length);
          return;
        }
        // Set a flag so only the trajectory is loaded
        window.loadTrajectoryOnly = true;
        const jsonInput = document.querySelector("#json-input");
        if (jsonInput) jsonInput.click();
      },
      fixGenerationFOV: () => {
        // These controllers will be set when GUI is initialized
        if (window.fixGenerationFOVController) window.fixGenerationFOVController.disable();
        fixGenerationFOV = true; // sets the module-level flag, not this property

        const new_camera = new THREE.PerspectiveCamera(guiOptions.FOV, guiOptions.Resolution.split('x')[2] / guiOptions.Resolution.split('x')[1]);
        new_camera.position.set(0, 0, 0);
        new_camera.quaternion.set(0, 0, 0, 1);
        new_camera.updateProjectionMatrix();

        const cameraSplat = createCameraSplat(new_camera);
        cameraSplats.push(cameraSplat);
        cameraParams.push({
          position: new_camera.position.clone(),
          quaternion: new_camera.quaternion.clone(),
          fov: new_camera.fov,
          aspect: new_camera.aspect,
        });
        scene.add(cameraSplat);

        updateStatus('Camera settings fixed. Press Space to record cameras.', cameraParams.length);
      },
      Resolution: `${supportedResolutions[0].frame}x${supportedResolutions[0].height}x${supportedResolutions[0].width}`,
      VisualizeCameraSplats: true,
      VisualizeInterpolatedCameras: true,
      inputImagePrompt: () => {
        const fileInput = document.querySelector("#file-input");
        if (fileInput) {
          // Only trigger file selection; the global handler does cropping and preview updates
          fileInput.click();
        }
      },
      imageIndex: 0,
      inputTextPrompt: "",

      // Camera trajectory templates
      trajectoryMode: "Manual",
      templateType: "Move Forward",
      cameraTrajectory: "Manual",
      trajectorySettings: {
        angle: 180, // angle in degrees (180, 360)
        tilt: 15    // tilt angle in degrees (15, 30, 45)
      },
      generateTrajectory: () => {
        generateCameraTrajectory(guiOptions.templateType);
      },
      saveTrajectoryToJson: () => {
        if (cameraParams.length === 0) {
          updateStatus('No cameras to save.', cameraParams.length);
          console.warn('No cameras to save');
          return;
        }

        // Build JSON payload compatible with loader
        const [nStr, hStr, wStr] = guiOptions.Resolution.split('x');
        const n = parseInt(nStr), h = parseInt(hStr), w = parseInt(wStr);
        const payload = {
          // image_prompt: null,
          // text_prompt: guiOptions.inputTextPrompt || "",
          // image_index: guiOptions.imageIndex || 0,
          // resolution: [n, h, w],
          cameras: cameraParams.map(cam => ({
            position: [cam.position.x, cam.position.y, cam.position.z],
            // Note: the JSON stores quaternions w-first; THREE.Quaternion fields are (x, y, z, w)
            quaternion: [cam.quaternion.w, cam.quaternion.x, cam.quaternion.y, cam.quaternion.z]
          }))
        };

        const blob = new Blob([JSON.stringify(payload, null, 2)], { type: 'application/json' });
        const url = URL.createObjectURL(blob);
        const a = document.createElement('a');
        a.href = url;
        a.download = `trajectory_${Date.now()}.json`;
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
        URL.revokeObjectURL(url);
        updateStatus('Trajectory saved to JSON.', cameraParams.length);
      },
      clearAllCameras: () => {
        if (cameraParams.length <= 1) {
          updateStatus('No cameras to clear (first camera is always preserved)', cameraParams.length);
          return;
        }

        // Keep the first camera, remove all others
        const firstCamera = cameraParams[0];
        const firstSplat = cameraSplats[0];

        // Remove all camera splats except the first one
        for (let i = cameraSplats.length - 1; i >= 1; i--) {
          scene.remove(cameraSplats[i]);
        }

        // Keep only the first camera in arrays
        cameraSplats.length = 1;
        cameraParams.length = 1;

        // Clear all interpolated camera splats from scene
        interpolatedCamerasSplats.forEach(splat => scene.remove(splat));
        interpolatedCamerasSplats.length = 0;

        updateStatus('Cameras cleared (first camera preserved). Ready to add more cameras.', 1);
        console.log('Cameras cleared, first camera preserved');
      },
      // Playback scrub value (0..1)
      playbackT: 0,

      generate: () => {
        // Require at least 2 cameras
        if (cameraParams.length < 2) {
          console.error('Need at least 2 cameras to generate. Please press Space to record more cameras.');
          updateStatus('Error: Need at least 2 cameras', cameraParams.length);
          return;
        }

        updateStatus('Preparing generation...', cameraParams.length);

        // Remove the previously generated scene
        if (currentGeneratedSplat) {
          scene.remove(currentGeneratedSplat);
          currentGeneratedSplat = null;
          console.log('Previous generated scene removed');
        }

        // Reset progress UI
        const generationTimeElement = document.getElementById('generation-time');
        const fileSizeElement = document.getElementById('file-size');
        const progressBar = document.getElementById('progress-bar');
        const progressText = document.getElementById('progress-text');

        if (generationTimeElement) generationTimeElement.textContent = '-';
        if (fileSizeElement) fileSizeElement.textContent = '-';
        if (progressBar) progressBar.style.width = '0%';
        if (progressText) progressText.textContent = '0%';

        // Hide generation info and download progress
        const generationInfo = document.getElementById('generation-info');
        const downloadProgress = document.getElementById('download-progress');
        if (generationInfo) generationInfo.style.display = 'none';
        if (downloadProgress) downloadProgress.style.display = 'none';

        showLoading(true);

        // Build interpolated cameras and visualize them
        const interpolatedCameras = interpolateCameras(cameraParams, parseInt(guiOptions.Resolution.split('x')[0]));
        interpolatedCameras.forEach(cam => {
          const interpolatedCameraSplat = createCameraSplat(cam, [0.5, 0.5, 0.5]);
          interpolatedCamerasSplats.push(interpolatedCameraSplat);
          scene.add(interpolatedCameraSplat);
        });

        console.log('Sending request to backend...');
        console.log('Interpolated cameras:', interpolatedCameras.length);
        updateStatus('Sending request to backend...', cameraParams.length);

        // Flask backend: POST directly to /generate
        const requestUrl = guiOptions.BackendAddress + '/generate';
        const requestBody = JSON.stringify({
          image_prompt: inputImageBase64 ? inputImageBase64 : "",
          text_prompt: guiOptions.inputTextPrompt,
          image_index: 0,
          resolution: [
            parseInt(guiOptions.Resolution.split('x')[0]),
            parseInt(guiOptions.Resolution.split('x')[1]),
            parseInt(guiOptions.Resolution.split('x')[2])
          ],
          cameras: interpolatedCameras.map(cam => ({
            position: [cam.position.x, cam.position.y, cam.position.z],
            quaternion: [cam.quaternion.w, cam.quaternion.x, cam.quaternion.y, cam.quaternion.z],
            // Pinhole intrinsics derived from the vertical FOV (see the note below guiOptions)
            fx: 0.5 / Math.tan(0.5 * cam.fov * Math.PI / 180) * parseInt(guiOptions.Resolution.split('x')[1]),
            fy: 0.5 / Math.tan(0.5 * cam.fov * Math.PI / 180) * parseInt(guiOptions.Resolution.split('x')[1]),
            cx: inputImageBase64 && inputImageResolution
              ? 0.5 * inputImageResolution.width
              : 0.5 * parseInt(guiOptions.Resolution.split('x')[2]),
            cy: inputImageBase64 && inputImageResolution
              ? 0.5 * inputImageResolution.height
              : 0.5 * parseInt(guiOptions.Resolution.split('x')[1]),
          }))
        });

        // Ask the backend to generate (async: returns a task_id, then poll the queue)
        fetch(requestUrl, {
          method: 'POST',
          headers: { 'Content-Type': 'application/json' },
          mode: 'cors',
          body: requestBody
        })
        .then(response => {
          const contentType = response.headers.get('content-type');
          if (contentType && contentType.includes('application/json')) {
            return response.json();
          } else {
            return response.blob().then(blob => {
              const url = URL.createObjectURL(blob);
              return { url };
            });
          }
        })
        .then(data => {
          console.log(data);
          // Async queue protocol: the backend returns task_id + queue info (202)
          if (data && data.success && data.task_id) {
            updateStatus('Queued request submitted. Waiting in queue...', cameraParams.length);
            showQueueWaiting(data.queue?.position || 0, data.queue?.running_count || 0, data.queue?.queued_count || 0);
            // Poll until the task completes, then download
            return pollTaskUntilReady(data.task_id).then(() => ({ url: null }));
          }
          // Backward compatibility with the old direct-file response format
          if (data && data.url) {
            updateStatus('Loading generated scene...', cameraParams.length);
            return Promise.resolve(data);
          }
          throw new Error('Invalid Flask response (expected task_id)');
        })
        .then(data => {
          if (data.url) {
            updateStatus('Loading 3D scene...', cameraParams.length);
            // Remove the instruction splat when generation is complete
            if (instructionSplat) {
              scene.remove(instructionSplat);
              console.log('Instruction splat removed');
            }
            const GeneratedSplat = new SplatMesh({ url: data.url });
            scene.add(GeneratedSplat);
            currentGeneratedSplat = GeneratedSplat; // keep a reference to the new scene
            console.log('3D scene loaded successfully!');
            updateStatus('Scene generated successfully!', cameraParams.length);
            hideDownloadProgress();
            showLoading(false);
          }
        })
        .catch(error => {
          console.error('Error:', error);
          updateStatus('Generation failed: ' + error.message, cameraParams.length);
          hideDownloadProgress();
          showLoading(false);
        });
      }
    };

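    // Intrinsics sanity check for the fx/fy sent in generate() above (pinhole model with
    // square pixels, vertical FOV): fx = fy = 0.5 * H / tan(FOV / 2). For FOV = 60° and
    // H = 480, fx = 240 / tan(30°) ≈ 415.7 pixels; cx, cy default to W / 2 and H / 2.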
    // Initialize renderer and GUI when DOM is ready
    function initializeApp() {
      try {
        // Debug layout
        console.log('Initializing app...');
        console.log('Center panel:', document.querySelector('.center-panel'));
        console.log('GUI container:', document.getElementById('gui-container'));
        console.log('Right panel:', document.querySelector('.right-panel'));

        initializeRenderer();
        initializeGUI();
        console.log('App initialization complete');
      } catch (error) {
        console.error('App initialization failed:', error);
      }
    }

    if (document.readyState === 'loading') {
      document.addEventListener('DOMContentLoaded', initializeApp);
    } else {
      initializeApp();
    }

    // =========================
    // Utility & Core Functions
    // =========================

    // Interpolate between two cameras
    function interpolateTwoCameras(startCamera, endCamera, _t) {
      const interpolatedCamera = new THREE.PerspectiveCamera(startCamera.fov, startCamera.aspect);

      // If _t is near 0, use startCamera directly
      if (_t < 1e-6) {
        interpolatedCamera.position.copy(startCamera.position);
        interpolatedCamera.quaternion.copy(startCamera.quaternion);
      }
      // If _t is near 1, use endCamera directly
      else if (_t > 1 - 1e-6) {
        interpolatedCamera.position.copy(endCamera.position);
        interpolatedCamera.quaternion.copy(endCamera.quaternion);
      }
      // Otherwise interpolate: lerp the position, slerp the orientation
      else {
        interpolatedCamera.position.copy(startCamera.position).lerp(endCamera.position, _t);
        interpolatedCamera.quaternion.copy(startCamera.quaternion).slerp(endCamera.quaternion, _t);
      }

      return interpolatedCamera;
    }

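    // Example: interpolateTwoCameras(camA, camB, 0.5) returns a camera halfway between
    // camA and camB, with the position linearly interpolated and the orientation slerped.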
    function interpolateCameras(cameras, M) {
      const interpolatedCameras = [];

      if (cameras.length === 0) {
        return interpolatedCameras;
      }

      if (cameras.length === 1) {
        // With a single camera, repeat it for every frame
        for (let i = 0; i < M; i++) {
          interpolatedCameras.push(cameras[0]);
        }
        return interpolatedCameras;
      }

      for (let i = 0; i < M; i++) {
        const t = i / (M - 1);
        const startIndex = Math.min(Math.floor(t * (cameras.length - 1)), cameras.length - 2);
        const endIndex = startIndex + 1;
        const startCamera = cameras[startIndex];
        const endCamera = cameras[endIndex];
        const _t = t * (cameras.length - 1) - startIndex;
        const interpolatedCamera = interpolateTwoCameras(startCamera, endCamera, _t);
        interpolatedCameras.push(interpolatedCamera);
      }
      return interpolatedCameras;
    }

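    // Example: with 3 keyframes and M = 5, t takes 0, 0.25, 0.5, 0.75, 1, which maps to
    // keyframe pairs (0,1), (0,1), (1,2), (1,2), (1,2) with _t = 0, 0.5, 0, 0.5, 1,
    // so the last frame lands exactly on keyframe 2.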
    // Create a splat visualization of a cube wireframe
    function createCubeSplat(size = 0.1, pointColor = [1, 1, 1]) {
      const cubeSplat = new SplatMesh({
        constructSplats: (splats) => {
          const NUM_SPLATS_PER_EDGE = 1000;
          const scales = new THREE.Vector3().setScalar(0.002);
          const quaternion = new THREE.Quaternion();
          const opacity = 1;
          const color = new THREE.Color(...pointColor);

          // The cube's 8 vertices
          const halfSize = size / 2;
          const vertices = [
            new THREE.Vector3(-halfSize, -halfSize, -halfSize), // 0: left-bottom-back
            new THREE.Vector3(halfSize, -halfSize, -halfSize),  // 1: right-bottom-back
            new THREE.Vector3(halfSize, halfSize, -halfSize),   // 2: right-top-back
            new THREE.Vector3(-halfSize, halfSize, -halfSize),  // 3: left-top-back
            new THREE.Vector3(-halfSize, -halfSize, halfSize),  // 4: left-bottom-front
            new THREE.Vector3(halfSize, -halfSize, halfSize),   // 5: right-bottom-front
            new THREE.Vector3(halfSize, halfSize, halfSize),    // 6: right-top-front
            new THREE.Vector3(-halfSize, halfSize, halfSize),   // 7: left-top-front
          ];

          // The cube's 12 edges
          const edges = [
            [0, 1], [1, 2], [2, 3], [3, 0], // 4 back edges
            [4, 5], [5, 6], [6, 7], [7, 4], // 4 front edges
            [0, 4], [1, 5], [2, 6], [3, 7], // 4 edges connecting front and back
          ];

          // Scatter splat points along each edge
          for (let i = 0; i < edges.length; i++) {
            const start = vertices[edges[i][0]];
            const end = vertices[edges[i][1]];
            for (let j = 0; j < NUM_SPLATS_PER_EDGE; j++) {
              const point = new THREE.Vector3().lerpVectors(start, end, j / NUM_SPLATS_PER_EDGE);
              splats.pushSplat(point, scales, quaternion, opacity, color);
            }
          }
        },
      });
      return cubeSplat;
    }

    // Create a splat visualization of a camera frustum
    function createCameraSplat(camera, pointColor = [1, 1, 1]) {
      const cameraSplat = new SplatMesh({
        constructSplats: (splats) => {
          const NUM_SPLATS_PER_EDGE = 1000;
          const LENGTH_PER_EDGE = 0.1;
          const center = new THREE.Vector3();
          const scales = new THREE.Vector3().setScalar(0.001);
          const quaternion = new THREE.Quaternion();
          const opacity = 1;
          const color = new THREE.Color(...pointColor);

          const H = 1000;
          const W = 1000 * camera.aspect;
          const fx = 0.5 * H / Math.tan(0.5 * camera.fov * Math.PI / 180);
          const fy = 0.5 * H / Math.tan(0.5 * camera.fov * Math.PI / 180);

          const xt = (0 - W / 2 + 0.5) / fy;
          const xb = (W - W / 2 + 0.5) / fy;
          const yl = - (0 - H / 2 + 0.5) / fx;
          const yr = - (H - H / 2 + 0.5) / fx;

          const lt = new THREE.Vector3(xt * LENGTH_PER_EDGE, yl * LENGTH_PER_EDGE, -1 * LENGTH_PER_EDGE);
          const rt = new THREE.Vector3(xt * LENGTH_PER_EDGE, yr * LENGTH_PER_EDGE, -1 * LENGTH_PER_EDGE);
          const lb = new THREE.Vector3(xb * LENGTH_PER_EDGE, yl * LENGTH_PER_EDGE, -1 * LENGTH_PER_EDGE);
          const rb = new THREE.Vector3(xb * LENGTH_PER_EDGE, yr * LENGTH_PER_EDGE, -1 * LENGTH_PER_EDGE);

          const lines = [
            [center, lt], [center, rt], [center, lb], [center, rb],
            [lt, rt], [lt, lb], [rt, rb], [lb, rb],
          ];

          for (let i = 0; i < lines.length; i++) {
            for (let j = 0; j < NUM_SPLATS_PER_EDGE; j++) {
              const point = new THREE.Vector3().lerpVectors(lines[i][0], lines[i][1], j / NUM_SPLATS_PER_EDGE);
              splats.pushSplat(point, scales, quaternion, opacity, color);
            }
          }
        },
      });
      cameraSplat.quaternion.copy(camera.quaternion);
      cameraSplat.position.copy(camera.position);
      return cameraSplat;
    }

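    // Example: createCameraSplat(camera, [1, 0, 0]) draws a red wireframe frustum at the
    // camera's current pose; the default color is white.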
| 1321 |
+
// 生成相机轨迹模板
|
| 1322 |
+
function generateCameraTrajectory(trajectoryType) {
|
| 1323 |
+
if (trajectoryType === "Manual") {
|
| 1324 |
+
updateStatus('Manual mode: Use Space to record cameras manually', cameraParams.length);
|
| 1325 |
+
return;
|
| 1326 |
+
}
|
| 1327 |
+
|
| 1328 |
+
// 检查FOV是否已固定
|
| 1329 |
+
if (!fixGenerationFOV) {
|
| 1330 |
+
updateStatus('Error: Please fix FOV first before generating trajectory', cameraParams.length);
|
| 1331 |
+
return;
|
| 1332 |
+
}
|
| 1333 |
+
|
| 1334 |
+
// 获取最后一个相机作为参考点
|
| 1335 |
+
let referenceCamera;
|
| 1336 |
+
if (cameraParams.length > 0) {
|
| 1337 |
+
// 使用最后一个已保存的相机作为参考
|
| 1338 |
+
const lastCamera = cameraParams[cameraParams.length - 1];
|
| 1339 |
+
referenceCamera = new THREE.PerspectiveCamera(guiOptions.FOV, camera.aspect);
|
| 1340 |
+
referenceCamera.position.copy(lastCamera.position);
|
| 1341 |
+
referenceCamera.quaternion.copy(lastCamera.quaternion);
|
| 1342 |
+
referenceCamera.updateProjectionMatrix();
|
| 1343 |
+
} else {
|
| 1344 |
+
// 如果没有已保存的相机,从原点开始
|
| 1345 |
+
referenceCamera = new THREE.PerspectiveCamera(guiOptions.FOV, camera.aspect);
|
| 1346 |
+
referenceCamera.position.set(0, 0, 0);
|
| 1347 |
+
referenceCamera.quaternion.set(0, 0, 0, 1);
|
| 1348 |
+
referenceCamera.updateProjectionMatrix();
|
| 1349 |
+
}
|
| 1350 |
+
|
| 1351 |
+
// 对于orbit,计算所有相机围绕的目标点
|
| 1352 |
+
// 始终使用当前参考相机(最后一个相机)来计算目标点
|
| 1353 |
+
let orbitTarget = null;
|
| 1354 |
+
let orbitStartCamera = null;
|
| 1355 |
+
if (trajectoryType.includes("Orbit") && cameraParams.length > 0) {
|
| 1356 |
+
// 使用最后一个相机作为参考,计算其前方1单位的目标点
|
| 1357 |
+
orbitStartCamera = cameraParams[cameraParams.length - 1];
|
| 1358 |
+
orbitTarget = orbitStartCamera.position.clone().add(
|
| 1359 |
+
new THREE.Vector3(0, 0, -1).applyQuaternion(orbitStartCamera.quaternion)
|
| 1360 |
+
);
|
| 1361 |
+
console.log("Orbit target calculated from last camera:", orbitStartCamera.position, "->", orbitTarget);
|
| 1362 |
+
} else if (trajectoryType.includes("Orbit")) {
|
| 1363 |
+
// 如果没有已记录的相机,使用当前相机作为参考
|
| 1364 |
+
orbitStartCamera = referenceCamera;
|
| 1365 |
+
orbitTarget = referenceCamera.position.clone().add(
|
| 1366 |
+
new THREE.Vector3(0, 0, -1).applyQuaternion(referenceCamera.quaternion)
|
| 1367 |
+
);
|
| 1368 |
+
console.log("Orbit target calculated from current camera:", referenceCamera.position, "->", orbitTarget);
|
| 1369 |
+
}
|
| 1370 |
+
|
| 1371 |
+
const cameras = [];
|
| 1372 |
+
const stepSize = 0.5; // 移动步长
|
| 1373 |
+
const totalOrbitAngle = 15 * Math.PI / 180; // 总共15度轨道
|
| 1374 |
+
|
| 1375 |
+
// 根据轨迹类型生成相机
|
| 1376 |
+
let numCameras = 1; // 默认生成1个相机
|
| 1377 |
+
if (trajectoryType.includes("Orbit")) {
|
| 1378 |
+
numCameras = 1; // 轨道运动生成1个相机
|
| 1379 |
+
console.log(`Generating ${numCameras} orbit camera with total angle ${totalOrbitAngle * 180 / Math.PI}°`);
|
| 1380 |
+
}
|
| 1381 |
+
|
| 1382 |
+
for (let i = 1; i <= numCameras; i++) {
|
| 1383 |
+
const newCamera = new THREE.PerspectiveCamera(guiOptions.FOV, camera.aspect);
|
| 1384 |
+
let position, quaternion;
|
| 1385 |
+
|
| 1386 |
+
switch (trajectoryType) {
|
| 1387 |
+
case "Move Forward":
|
| 1388 |
+
position = referenceCamera.position.clone();
|
| 1389 |
+
position.z -= stepSize;
|
| 1390 |
+
quaternion = referenceCamera.quaternion.clone();
|
| 1391 |
+
break;
|
| 1392 |
+
|
| 1393 |
+
case "Move Backward":
|
| 1394 |
+
position = referenceCamera.position.clone();
|
| 1395 |
+
position.z += stepSize;
|
| 1396 |
+
quaternion = referenceCamera.quaternion.clone();
|
| 1397 |
+
break;
|
| 1398 |
+
|
| 1399 |
+
case "Move Left":
|
| 1400 |
+
position = referenceCamera.position.clone();
|
| 1401 |
+
position.x -= stepSize;
|
| 1402 |
+
quaternion = referenceCamera.quaternion.clone();
|
| 1403 |
+
break;
|
| 1404 |
+
|
| 1405 |
+
case "Move Right":
|
| 1406 |
+
position = referenceCamera.position.clone();
|
| 1407 |
+
position.x += stepSize;
|
| 1408 |
+
quaternion = referenceCamera.quaternion.clone();
|
| 1409 |
+
break;
|
| 1410 |
+
|
| 1411 |
+
case "Orbit Left 15°":
|
| 1412 |
+
const radius = 1.0;
|
| 1413 |
+
// Left orbit: -15 degrees
|
| 1414 |
+
const angle = -totalOrbitAngle;
|
| 1415 |
+
|
| 1416 |
+
console.log(`Camera ${i}: angle=${angle * 180 / Math.PI}° (Left)`);
|
| 1417 |
+
|
| 1418 |
+
// Compute the orbit position in the reference camera's local frame
|
| 1419 |
+
const localOrbitPos = new THREE.Vector3(
|
| 1420 |
+
Math.sin(angle) * radius,
|
| 1421 |
+
0,
|
| 1422 |
+
Math.cos(angle) * radius
|
| 1423 |
+
);
|
| 1424 |
+
|
| 1425 |
+
// Transform to world space: rotate into the reference camera's orientation
|
| 1426 |
+
const worldOrbitPos = localOrbitPos.applyQuaternion(orbitStartCamera.quaternion);
|
| 1427 |
+
|
| 1428 |
+
// Final position: start from the target point and add the world-space offset
|
| 1429 |
+
position = orbitTarget.clone().add(worldOrbitPos);
|
| 1430 |
+
|
| 1431 |
+
console.log(`Orbit Left camera ${i}: localPos=`, localOrbitPos, 'worldPos=', worldOrbitPos, 'finalPos=', position);
|
| 1432 |
+
|
| 1433 |
+
// Orientation: every camera looks at the orbit center (the target point)
|
| 1434 |
+
const lookDirection = orbitTarget.clone().sub(position).normalize();
|
| 1435 |
+
quaternion = new THREE.Quaternion().setFromUnitVectors(
|
| 1436 |
+
new THREE.Vector3(0, 0, -1),
|
| 1437 |
+
lookDirection
|
| 1438 |
+
);
|
| 1439 |
+
|
| 1440 |
+
console.log(`Orbit Left camera ${i}: quaternion=`, quaternion);
|
| 1441 |
+
break;
|
| 1442 |
+
|
| 1443 |
+
case "Orbit Right 15°":
|
| 1444 |
+
const radiusRight = 1.0;
|
| 1445 |
+
// Right orbit: +15 degrees
|
| 1446 |
+
const angleRight = totalOrbitAngle;
|
| 1447 |
+
|
| 1448 |
+
console.log(`Camera ${i}: angle=${angleRight * 180 / Math.PI}° (Right)`);
|
| 1449 |
+
|
| 1450 |
+
// Compute the orbit position in the reference camera's local frame
|
| 1451 |
+
const localOrbitPosRight = new THREE.Vector3(
|
| 1452 |
+
Math.sin(angleRight) * radiusRight,
|
| 1453 |
+
0,
|
| 1454 |
+
Math.cos(angleRight) * radiusRight
|
| 1455 |
+
);
|
| 1456 |
+
|
| 1457 |
+
// Transform to world space: rotate into the reference camera's orientation
|
| 1458 |
+
const worldOrbitPosRight = localOrbitPosRight.applyQuaternion(orbitStartCamera.quaternion);
|
| 1459 |
+
|
| 1460 |
+
// Final position: start from the target point and add the world-space offset
|
| 1461 |
+
position = orbitTarget.clone().add(worldOrbitPosRight);
|
| 1462 |
+
|
| 1463 |
+
console.log(`Orbit Right camera ${i}: localPos=`, localOrbitPosRight, 'worldPos=', worldOrbitPosRight, 'finalPos=', position);
|
| 1464 |
+
|
| 1465 |
+
// Orientation: every camera looks at the orbit center (the target point)
|
| 1466 |
+
const lookDirectionRight = orbitTarget.clone().sub(position).normalize();
|
| 1467 |
+
quaternion = new THREE.Quaternion().setFromUnitVectors(
|
| 1468 |
+
new THREE.Vector3(0, 0, -1),
|
| 1469 |
+
lookDirectionRight
|
| 1470 |
+
);
|
| 1471 |
+
|
| 1472 |
+
console.log(`Orbit Right camera ${i}: quaternion=`, quaternion);
|
| 1473 |
+
break;
|
| 1474 |
+
|
| 1475 |
+
|
| 1476 |
+
default:
|
| 1477 |
+
position = referenceCamera.position.clone();
|
| 1478 |
+
quaternion = referenceCamera.quaternion.clone();
|
| 1479 |
+
}
|
| 1480 |
+
|
| 1481 |
+
newCamera.position.copy(position);
|
| 1482 |
+
newCamera.quaternion.copy(quaternion);
|
| 1483 |
+
newCamera.updateProjectionMatrix();
|
| 1484 |
+
cameras.push(newCamera);
|
| 1485 |
+
}
|
| 1486 |
+
|
| 1487 |
+
// Add the cameras to the scene
|
| 1488 |
+
cameras.forEach(cam => {
|
| 1489 |
+
const cameraSplat = createCameraSplat(cam);
|
| 1490 |
+
cameraSplats.push(cameraSplat);
|
| 1491 |
+
cameraParams.push({
|
| 1492 |
+
position: cam.position.clone(),
|
| 1493 |
+
quaternion: cam.quaternion.clone(),
|
| 1494 |
+
fov: cam.fov,
|
| 1495 |
+
aspect: cam.aspect,
|
| 1496 |
+
});
|
| 1497 |
+
scene.add(cameraSplat);
|
| 1498 |
+
});
|
| 1499 |
+
|
| 1500 |
+
updateStatus(`Added ${cameras.length} cameras using ${trajectoryType} trajectory`, cameraParams.length);
|
| 1501 |
+
console.log(`Added ${cameras.length} cameras using ${trajectoryType} trajectory`);
|
| 1502 |
+
}
|
| 1503 |
+
|
| 1504 |
+
// =========================
|
| 1505 |
+
// GUI & User Interaction
|
| 1506 |
+
// =========================
|
| 1507 |
+
|
| 1508 |
+
// GUI controls - initialized lazily
|
| 1509 |
+
function initializeGUI() {
|
| 1510 |
+
const guiContainer = document.getElementById('gui-container');
|
| 1511 |
+
if (guiContainer && !gui) {
|
| 1512 |
+
// Clear any existing content
|
| 1513 |
+
guiContainer.innerHTML = '';
|
| 1514 |
+
|
| 1515 |
+
gui = new GUI({ title: "FlashWorld Controls", container: guiContainer });
|
| 1516 |
+
console.log('GUI initialized in container:', guiContainer);
|
| 1517 |
+
|
| 1518 |
+
// Step 1: Configure Generation Settings
|
| 1519 |
+
const step1Folder = gui.addFolder('1. Configure Settings');
|
| 1520 |
+
step1Folder.add(guiOptions, "BackendAddress").name("Backend Address");
|
| 1521 |
+
|
| 1522 |
+
// FOV and Resolution controllers, enabled initially
|
| 1523 |
+
const fovController = step1Folder.add(guiOptions, "FOV", 0, 120, 1).name("FOV").onChange((value) => {
|
| 1524 |
+
camera.fov = value;
|
| 1525 |
+
camera.updateProjectionMatrix();
|
| 1526 |
+
});
|
| 1527 |
+
const resolutionController = step1Folder.add(guiOptions, "Resolution", supportedResolutions.map(
|
| 1528 |
+
r => `${r.frame}x${r.height}x${r.width}`
|
| 1529 |
+
)).name("Resolution (NxHxW)").onChange((value) => {
|
| 1530 |
+
updateCanvasSize();
|
| 1531 |
+
});
|
| 1532 |
+
|
| 1533 |
+
// The Fix Configuration button goes at the bottom
|
| 1534 |
+
const fixGenerationFOVController = step1Folder.add(guiOptions, "fixGenerationFOV").name("Fix Configuration");
|
| 1535 |
+
step1Folder.open();
|
| 1536 |
+
|
| 1537 |
+
// Step 2: Set Up Camera Path
|
| 1538 |
+
const step2Folder = gui.addFolder('2. Set Up Camera Path');
|
| 1539 |
+
|
| 1540 |
+
// Camera trajectory templates
|
| 1541 |
+
const trajectoryFolder = step2Folder.addFolder('Camera Trajectory');
|
| 1542 |
+
|
| 1543 |
+
// Trajectory mode selection
|
| 1544 |
+
const trajectoryModeController = trajectoryFolder.add(guiOptions, "trajectoryMode", [
|
| 1545 |
+
"Manual",
|
| 1546 |
+
"Template",
|
| 1547 |
+
"JSON"
|
| 1548 |
+
]).name("Trajectory Mode");
|
| 1549 |
+
|
| 1550 |
+
// Template type selection (only available in Template mode)
|
| 1551 |
+
const templateTypeController = trajectoryFolder.add(guiOptions, "templateType", [
|
| 1552 |
+
"Move Forward",
|
| 1553 |
+
"Move Backward",
|
| 1554 |
+
"Move Left",
|
| 1555 |
+
"Move Right",
|
| 1556 |
+
"Orbit Left 15°",
|
| 1557 |
+
"Orbit Right 15°"
|
| 1558 |
+
]).name("Template Type");
|
| 1559 |
+
|
| 1560 |
+
// Generate-trajectory button
|
| 1561 |
+
const generateTrajectoryController = trajectoryFolder.add(guiOptions, "generateTrajectory").name("Generate Trajectory");
|
| 1562 |
+
|
| 1563 |
+
// Buttons to load/save the trajectory as JSON
|
| 1564 |
+
const loadTrajectoryController = trajectoryFolder.add(guiOptions, "LoadTrajectoryFromJson").name("Load from JSON");
|
| 1565 |
+
const saveTrajectoryController = trajectoryFolder.add(guiOptions, "saveTrajectoryToJson").name("Save Trajectory");
|
| 1566 |
+
|
| 1567 |
+
// Button to clear all cameras
|
| 1568 |
+
const clearAllCamerasController = trajectoryFolder.add(guiOptions, "clearAllCameras").name("Clear All Cameras");
|
| 1569 |
+
|
| 1570 |
+
// Initial state: disable all trajectory-related controls
|
| 1571 |
+
templateTypeController.disable();
|
| 1572 |
+
generateTrajectoryController.disable();
|
| 1573 |
+
loadTrajectoryController.disable();
|
| 1574 |
+
|
| 1575 |
+
// Handle trajectory mode changes
|
| 1576 |
+
trajectoryModeController.onChange((value) => {
|
| 1577 |
+
if (value === "Manual") {
|
| 1578 |
+
templateTypeController.disable();
|
| 1579 |
+
generateTrajectoryController.disable();
|
| 1580 |
+
loadTrajectoryController.disable();
|
| 1581 |
+
} else if (value === "Template") {
|
| 1582 |
+
templateTypeController.enable();
|
| 1583 |
+
if (fixGenerationFOV) {
|
| 1584 |
+
generateTrajectoryController.enable();
|
| 1585 |
+
} else {
|
| 1586 |
+
generateTrajectoryController.disable();
|
| 1587 |
+
}
|
| 1588 |
+
loadTrajectoryController.disable();
|
| 1589 |
+
} else if (value === "JSON") {
|
| 1590 |
+
templateTypeController.disable();
|
| 1591 |
+
generateTrajectoryController.disable();
|
| 1592 |
+
if (fixGenerationFOV) {
|
| 1593 |
+
loadTrajectoryController.enable();
|
| 1594 |
+
} else {
|
| 1595 |
+
loadTrajectoryController.disable();
|
| 1596 |
+
}
|
| 1597 |
+
}
|
| 1598 |
+
});
|
| 1599 |
+
|
| 1600 |
+
// Enable trajectory generation once the configuration is fixed
|
| 1601 |
+
const originalFixFOV = guiOptions.fixGenerationFOV;
|
| 1602 |
+
guiOptions.fixGenerationFOV = () => {
|
| 1603 |
+
originalFixFOV();
|
| 1604 |
+
|
| 1605 |
+
// After Fix Configuration, disable all Step 1 controllers
|
| 1606 |
+
fovController.disable();
|
| 1607 |
+
resolutionController.disable();
|
| 1608 |
+
|
| 1609 |
+
// Enable the controls that match the current trajectory mode
|
| 1610 |
+
if (guiOptions.trajectoryMode === "Template") {
|
| 1611 |
+
generateTrajectoryController.enable();
|
| 1612 |
+
} else if (guiOptions.trajectoryMode === "JSON") {
|
| 1613 |
+
loadTrajectoryController.enable();
|
| 1614 |
+
}
|
| 1615 |
+
updateStatus('Configuration fixed. You can now generate camera trajectory.', cameraParams.length);
|
| 1616 |
+
};
|
| 1617 |
+
|
| 1618 |
+
trajectoryFolder.open();
|
| 1619 |
+
|
| 1620 |
+
step2Folder.add(guiOptions, "VisualizeCameraSplats").name("Visualize Cameras").onChange((value) => {
|
| 1621 |
+
cameraSplats.forEach(cameraSplat => {
|
| 1622 |
+
cameraSplat.opacity = value ? 1 : 0;
|
| 1623 |
+
});
|
| 1624 |
+
});
|
| 1625 |
+
step2Folder.add(guiOptions, "VisualizeInterpolatedCameras").name("Visualize Interpolated Cameras").onChange((value) => {
|
| 1626 |
+
interpolatedCamerasSplats.forEach(interpolatedCameraSplat => {
|
| 1627 |
+
interpolatedCameraSplat.opacity = value ? 1 : 0;
|
| 1628 |
+
});
|
| 1629 |
+
});
|
| 1630 |
+
|
| 1631 |
+
// Store controllers globally so they can be accessed from guiOptions
|
| 1632 |
+
window.fixGenerationFOVController = fixGenerationFOVController;
|
| 1633 |
+
|
| 1634 |
+
// Step 3: Add Scene Prompts
|
| 1635 |
+
const step3Folder = gui.addFolder('3. Add Scene Prompts');
|
| 1636 |
+
step3Folder.add(guiOptions, "inputImagePrompt").name("Input Image Prompt");
|
| 1637 |
+
step3Folder.add(guiOptions, "inputTextPrompt").name("Input Text Prompt");
|
| 1638 |
+
step3Folder.add(guiOptions, "imageIndex", 0, 24, 1).name("Image Index");
|
| 1639 |
+
|
| 1640 |
+
|
| 1641 |
+
// Step 4: Generate Your Scene
|
| 1642 |
+
const step4Folder = gui.addFolder('4. Generate Scene');
|
| 1643 |
+
step4Folder.add(guiOptions, "generate").name("Generate!");
|
| 1644 |
+
step4Folder.open();
|
| 1645 |
+
|
| 1646 |
+
// Step 5: Trajectory Playback (Scrubber)
|
| 1647 |
+
const step5Folder = gui.addFolder('5. Trajectory Playback');
|
| 1648 |
+
step5Folder.add(guiOptions, 'playbackT', 0, 1, 0.001).name('Scrub (0-1)').onChange((value) => {
|
| 1649 |
+
// On the first scrub, record the user camera state so it can be restored if needed (optional)
|
| 1650 |
+
if (!userCameraState) {
|
| 1651 |
+
userCameraState = {
|
| 1652 |
+
position: camera.position.clone(),
|
| 1653 |
+
quaternion: camera.quaternion.clone(),
|
| 1654 |
+
fov: camera.fov
|
| 1655 |
+
};
|
| 1656 |
+
}
|
| 1657 |
+
setCameraByScrub(value);
|
| 1658 |
+
updateStatus(`Scrubbing trajectory: t=${value.toFixed(3)}`, cameraParams.length);
|
| 1659 |
+
});
|
| 1660 |
+
step5Folder.open();
|
| 1661 |
+
|
| 1662 |
+
}
|
| 1663 |
+
}
|
| 1664 |
+
|
| 1665 |
+
|
| 1666 |
+
// =========================
|
| 1667 |
+
// File Input (Image Prompt)
|
| 1668 |
+
// =========================
|
| 1669 |
+
const fileInput = document.querySelector("#file-input");
|
| 1670 |
+
fileInput.onchange = (event) => {
|
| 1671 |
+
const files = event.target.files;
|
| 1672 |
+
if (!files || files.length === 0) return;
|
| 1673 |
+
Array.from(files).forEach(file => {
|
| 1674 |
+
const reader = new FileReader();
|
| 1675 |
+
reader.onload = function(e) {
|
| 1676 |
+
console.log("Loaded image:", file.name, e.target.result);
|
| 1677 |
+
|
| 1678 |
+
// Get the current Resolution
|
| 1679 |
+
let resolutionStr = guiOptions.Resolution;
|
| 1680 |
+
let [n, h, w] = resolutionStr.split('x').map(Number);
|
| 1681 |
+
|
| 1682 |
+
// Load the image
|
| 1683 |
+
const img = new Image();
|
| 1684 |
+
img.onload = function() {
|
| 1685 |
+
window.inputImageResolution = { width: img.width, height: img.height };
|
| 1686 |
+
console.log("Input image resolution:", window.inputImageResolution);
|
| 1687 |
+
|
| 1688 |
+
// Compute the center-crop parameters
|
| 1689 |
+
let scaleH = h / img.height;
|
| 1690 |
+
let scaleW = w / img.width;
|
| 1691 |
+
let scale = Math.max(scaleH, scaleW);
|
| 1692 |
+
let newW = Math.round(w / scale);
|
| 1693 |
+
let newH = Math.round(h / scale);
|
| 1694 |
+
let sx = Math.floor((img.width - newW) / 2);
|
| 1695 |
+
let sy = Math.floor((img.height - newH) / 2);
|
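// Worked example: for a 1920x1080 input and a 640x480 target,
// scale = max(480/1080, 640/1920) = 0.444..., so newW = 1440, newH = 1080,
// sx = 240, sy = 0: a centered 1440x1080 crop that is then resized to
// 640x480, matching the target aspect ratio without letterboxing.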
| 1696 |
+
|
| 1697 |
+
// Create a canvas to perform the center crop and resize
|
| 1698 |
+
const canvas = document.createElement('canvas');
|
| 1699 |
+
canvas.width = w;
|
| 1700 |
+
canvas.height = h;
|
| 1701 |
+
const ctx = canvas.getContext('2d');
|
| 1702 |
+
ctx.drawImage(
|
| 1703 |
+
img,
|
| 1704 |
+
sx, sy, newW, newH, // source crop
|
| 1705 |
+
0, 0, w, h // destination size
|
| 1706 |
+
);
|
| 1707 |
+
// Produce the cropped and resized base64 (sent to the backend)
|
| 1708 |
+
inputImageBase64 = canvas.toDataURL('image/png');
|
| 1709 |
+
// Update the preview with the cropped image
|
| 1710 |
+
const previewArea = document.getElementById('image-preview-area');
|
| 1711 |
+
const previewImg = document.getElementById('preview-img');
|
| 1712 |
+
if (previewImg && previewArea) {
|
| 1713 |
+
previewImg.src = inputImageBase64;
|
| 1714 |
+
previewArea.style.display = 'block';
|
| 1715 |
+
}
|
| 1716 |
+
// Record the resolution sent to the backend (now aligned with the current Resolution)
|
| 1717 |
+
window.inputImageResolution = { width: w, height: h };
|
| 1718 |
+
console.log("Cropped and resized image to:", w, h);
|
| 1719 |
+
};
|
| 1720 |
+
img.src = e.target.result;
|
| 1721 |
+
};
|
| 1722 |
+
reader.readAsDataURL(file);
|
| 1723 |
+
});
|
| 1724 |
+
|
| 1725 |
+
};
|
| 1726 |
+
|
| 1727 |
+
// =========================
|
| 1728 |
+
// File Input (JSON)
|
| 1729 |
+
// =========================
|
| 1847 |
+
const jsonInput = document.querySelector("#json-input");
|
| 1848 |
+
jsonInput.onchange = (event) => {
|
| 1849 |
+
const files = event.target.files;
|
| 1850 |
+
if (!files || files.length === 0) return;
|
| 1851 |
+
const file = files[0];
|
| 1852 |
+
const reader = new FileReader();
|
| 1853 |
+
reader.onload = function(e) {
|
| 1854 |
+
let jsonData;
|
| 1855 |
+
try {
|
| 1856 |
+
jsonData = JSON.parse(e.target.result);
|
| 1857 |
+
} catch (error) {
|
| 1858 |
+
console.error("JSON parsing error:", error);
|
| 1859 |
+
return;
|
| 1860 |
+
}
|
| 1861 |
+
|
| 1862 |
+
// Check whether only the trajectory should be loaded
|
| 1863 |
+
const loadTrajectoryOnly = window.loadTrajectoryOnly;
|
| 1864 |
+
window.loadTrajectoryOnly = false; // reset the flag
|
| 1865 |
+
|
| 1866 |
+
if (loadTrajectoryOnly) {
|
| 1867 |
+
// Trajectory-only load: clear all existing cameras and interpolated cameras
|
| 1868 |
+
cameraSplats.forEach(splat => scene.remove(splat));
|
| 1869 |
+
cameraSplats.length = 0;
|
| 1870 |
+
cameraParams.length = 0;
|
| 1871 |
+
interpolatedCamerasSplats.forEach(splat => scene.remove(splat));
|
| 1872 |
+
interpolatedCamerasSplats.length = 0;
|
| 1873 |
+
} else {
|
| 1874 |
+
// Full JSON load: clear all existing cameras and interpolated cameras
|
| 1875 |
+
cameraSplats.forEach(splat => scene.remove(splat));
|
| 1876 |
+
cameraSplats.length = 0;
|
| 1877 |
+
cameraParams.length = 0;
|
| 1878 |
+
interpolatedCamerasSplats.forEach(splat => scene.remove(splat));
|
| 1879 |
+
interpolatedCamerasSplats.length = 0;
|
| 1880 |
+
}
|
| 1881 |
+
|
| 1882 |
+
try {
|
| 1883 |
+
// Accept differently named fields
|
| 1884 |
+
const imagePrompt = jsonData.image_prompt || jsonData.imagePrompt || null;
|
| 1885 |
+
const textPrompt = jsonData.text_prompt || jsonData.textPrompt || "";
|
| 1886 |
+
const cameras = jsonData.cameras || [];
|
| 1887 |
+
const resolution = jsonData.resolution || [16, 480, 640];
|
| 1888 |
+
const imageIndex = jsonData.image_index || jsonData.imageIndex || 0;
|
| 1889 |
+
|
| 1890 |
+
console.log("Loaded JSON data:", {
|
| 1891 |
+
imagePrompt,
|
| 1892 |
+
textPrompt,
|
| 1893 |
+
cameras: cameras.length,
|
| 1894 |
+
resolution,
|
| 1895 |
+
imageIndex
|
| 1896 |
+
});
|
| 1897 |
+
|
| 1898 |
+
// Handle the image prompt (skipped in trajectory-only mode)
|
| 1899 |
+
if (!loadTrajectoryOnly && imagePrompt) {
|
| 1900 |
+
inputImageBase64 = imagePrompt;
|
| 1901 |
+
console.log("Image prompt loaded");
|
| 1902 |
+
}
|
| 1903 |
+
|
| 1904 |
+
// Set the text prompt (skipped in trajectory-only mode)
|
| 1905 |
+
if (!loadTrajectoryOnly) {
|
| 1906 |
+
guiOptions.inputTextPrompt = textPrompt;
|
| 1907 |
+
guiOptions.imageIndex = imageIndex;
|
| 1908 |
+
}
|
| 1909 |
+
|
| 1910 |
+
// Process the camera data
|
| 1911 |
+
if (cameras && cameras.length > 0) {
|
| 1912 |
+
let jsonFirstCamera = null;
|
| 1913 |
+
let jsonFirstPosition = null;
|
| 1914 |
+
let jsonFirstQuaternion = null;
|
| 1915 |
+
|
| 1916 |
+
// First, read the position and quaternion of the first camera in the JSON
|
| 1917 |
+
if (loadTrajectoryOnly && cameras.length > 0) {
|
| 1918 |
+
const firstCameraData = cameras[0];
|
| 1919 |
+
if (Array.isArray(firstCameraData.position) && firstCameraData.position.length === 3) {
|
| 1920 |
+
jsonFirstPosition = new THREE.Vector3(
|
| 1921 |
+
firstCameraData.position[0],
|
| 1922 |
+
firstCameraData.position[1],
|
| 1923 |
+
firstCameraData.position[2]
|
| 1924 |
+
);
|
| 1925 |
+
}
|
| 1926 |
+
if (Array.isArray(firstCameraData.quaternion) && firstCameraData.quaternion.length === 4) {
|
| 1927 |
+
jsonFirstQuaternion = new THREE.Quaternion(
|
| 1928 |
+
firstCameraData.quaternion[1],
|
| 1929 |
+
firstCameraData.quaternion[2],
|
| 1930 |
+
firstCameraData.quaternion[3],
|
| 1931 |
+
firstCameraData.quaternion[0]
|
| 1932 |
+
);
|
| 1933 |
+
}
|
| 1934 |
+
}
|
| 1935 |
+
|
| 1936 |
+
cameras.forEach((cameraData, index) => {
|
| 1937 |
+
// Parse the resolution
|
| 1938 |
+
let aspect = 1.0;
|
| 1939 |
+
if (Array.isArray(resolution) && resolution.length === 3) {
|
| 1940 |
+
aspect = resolution[2] / resolution[1];
|
| 1941 |
+
} else {
|
| 1942 |
+
aspect = guiOptions.Resolution.split('x')[2] / guiOptions.Resolution.split('x')[1];
|
| 1943 |
+
}
|
| 1944 |
+
|
| 1945 |
+
// Choose the FOV according to the load mode
|
| 1946 |
+
let fov = 60;
|
| 1947 |
+
if (loadTrajectoryOnly) {
|
| 1948 |
+
// Trajectory load: use the FOV set in the GUI
|
| 1949 |
+
fov = guiOptions.FOV;
|
| 1950 |
+
} else {
|
| 1951 |
+
// Full JSON load: use the FOV derived from the JSON intrinsics, or the default
|
| 1952 |
+
if (cameraData.fx && cameraData.fy) {
|
| 1953 |
+
fov = 2 * Math.atan(0.5 / cameraData.fx) * 180 / Math.PI;
|
| 1954 |
+
}
|
| 1955 |
+
}
|
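// Note: fov = 2 * atan(0.5 / fx) treats fx as a focal length normalized by
// the image dimension, so e.g. fx = 0.866 gives a FOV of about 60 degrees.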
| 1956 |
+
|
| 1957 |
+
const cam = new THREE.PerspectiveCamera(fov, aspect);
|
| 1958 |
+
|
| 1959 |
+
// Set position and quaternion
|
| 1960 |
+
if (Array.isArray(cameraData.position) && cameraData.position.length === 3) {
|
| 1961 |
+
cam.position.set(cameraData.position[0], cameraData.position[1], cameraData.position[2]);
|
| 1962 |
+
}
|
| 1963 |
+
|
| 1964 |
+
if (Array.isArray(cameraData.quaternion) && cameraData.quaternion.length === 4) {
|
| 1965 |
+
// Note: the JSON stores quaternions as (w, x, y, z); three.js expects (x, y, z, w)
|
| 1966 |
+
cam.quaternion.set(
|
| 1967 |
+
cameraData.quaternion[1],
|
| 1968 |
+
cameraData.quaternion[2],
|
| 1969 |
+
cameraData.quaternion[3],
|
| 1970 |
+
cameraData.quaternion[0]
|
| 1971 |
+
);
|
| 1972 |
+
}
|
| 1973 |
+
|
| 1974 |
+
// Trajectory load: force the first camera to the origin
|
| 1975 |
+
// if (loadTrajectoryOnly && index === 0) {
|
| 1976 |
+
// cam.position.set(0, 0, 0);
|
| 1977 |
+
// cam.quaternion.set(0, 0, 0, 1);
|
| 1978 |
+
// }
|
| 1979 |
+
|
| 1980 |
+
// Trajectory load: normalize the poses relative to the fixed-FOV camera
|
| 1981 |
+
if (loadTrajectoryOnly && jsonFirstPosition && jsonFirstQuaternion) {
|
| 1982 |
+
// Mirrors the normalization logic of the Python code
|
| 1983 |
+
// 1. Compute the c2w matrix of the first camera in the JSON
|
| 1984 |
+
const jsonFirstC2W = new THREE.Matrix4();
|
| 1985 |
+
jsonFirstC2W.compose(jsonFirstPosition, jsonFirstQuaternion, new THREE.Vector3(1, 1, 1));
|
| 1986 |
+
|
| 1987 |
+
// 2. Compute the current camera's c2w matrix
|
| 1988 |
+
const currentC2W = new THREE.Matrix4();
|
| 1989 |
+
currentC2W.compose(cam.position, cam.quaternion, new THREE.Vector3(1, 1, 1));
|
| 1990 |
+
|
| 1991 |
+
// 3. Compute the relative transform: ref_w2c @ current_c2w
|
| 1992 |
+
const refW2C = jsonFirstC2W.clone().invert();
|
| 1993 |
+
const relativeTransform = refW2C.clone().multiply(currentC2W);
|
| 1994 |
+
|
| 1995 |
+
// 4. Apply the relative transform to the origin camera (used as the reference)
|
| 1996 |
+
const fixedC2W = new THREE.Matrix4();
|
| 1997 |
+
fixedC2W.compose(new THREE.Vector3(0, 0, 0), new THREE.Quaternion(0, 0, 0, 1), new THREE.Vector3(1, 1, 1));
|
| 1998 |
+
|
| 1999 |
+
const newTransform = fixedC2W.clone().multiply(relativeTransform);
|
| 2000 |
+
|
| 2001 |
+
// 5. Extract the new position and rotation
|
| 2002 |
+
const newPosition = new THREE.Vector3();
|
| 2003 |
+
const newQuaternion = new THREE.Quaternion();
|
| 2004 |
+
const newScale = new THREE.Vector3();
|
| 2005 |
+
newTransform.decompose(newPosition, newQuaternion, newScale);
|
| 2006 |
+
|
| 2007 |
+
cam.position.copy(newPosition);
|
| 2008 |
+
cam.quaternion.copy(newQuaternion);
|
| 2009 |
+
}
|
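// In matrix form: new_c2w = fixed_c2w * inverse(ref_c2w) * cam_c2w. Since
// fixed_c2w is the identity, this re-bases the whole trajectory so that its
// first camera sits at the origin with identity rotation.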
| 2010 |
+
|
| 2011 |
+
// Set FOV and focal length (only when loading the full JSON)
|
| 2012 |
+
if (!loadTrajectoryOnly && cameraData.fx && cameraData.fy) {
|
| 2013 |
+
cam.fov = fov;
|
| 2014 |
+
cam.aspect = cameraData.fx / cameraData.fy;
|
| 2015 |
+
cam.updateProjectionMatrix();
|
| 2016 |
+
} else if (loadTrajectoryOnly) {
|
| 2017 |
+
// Trajectory load: use the FOV and aspect set in the GUI
|
| 2018 |
+
cam.fov = fov;
|
| 2019 |
+
cam.aspect = aspect;
|
| 2020 |
+
cam.updateProjectionMatrix();
|
| 2021 |
+
}
|
| 2022 |
+
|
| 2023 |
+
const cameraSplat = createCameraSplat(cam);
|
| 2024 |
+
cameraSplats.push(cameraSplat);
|
| 2025 |
+
cameraParams.push({
|
| 2026 |
+
position: cam.position.clone(),
|
| 2027 |
+
quaternion: cam.quaternion.clone(),
|
| 2028 |
+
fov: cam.fov,
|
| 2029 |
+
aspect: cam.aspect,
|
| 2030 |
+
});
|
| 2031 |
+
scene.add(cameraSplat);
|
| 2032 |
+
});
|
| 2033 |
+
|
| 2034 |
+
console.log(cameraParams);
|
| 2035 |
+
}
|
| 2036 |
+
|
| 2037 |
+
// Set the resolution (only when loading the full JSON)
|
| 2038 |
+
if (!loadTrajectoryOnly && Array.isArray(resolution) && resolution.length === 3) {
|
| 2039 |
+
guiOptions.Resolution = `${resolution[0]}x${resolution[1]}x${resolution[2]}`;
|
| 2040 |
+
}
|
| 2041 |
+
|
| 2042 |
+
// Show a success message
|
| 2043 |
+
if (loadTrajectoryOnly) {
|
| 2044 |
+
updateStatus(`Trajectory loaded: ${cameras.length} cameras`, cameraParams.length);
|
| 2045 |
+
} else {
|
| 2046 |
+
updateStatus(`JSON loaded: ${cameras.length} cameras`, cameraParams.length);
}
|
| 2047 |
+
} catch (error) {
|
| 2048 |
+
console.error("JSON data processing error:", error);
|
| 2049 |
+
}
|
| 2050 |
+
};
|
| 2051 |
+
reader.readAsText(file);
|
| 2052 |
+
};
|
| 2053 |
+
|
| 2054 |
+
// =========================
|
| 2055 |
+
// Keyboard Controls
|
| 2056 |
+
// =========================
|
| 2057 |
+
document.addEventListener('keypress', (event) => {
|
| 2058 |
+
if (event.code === 'Space') {
|
| 2059 |
+
if (!fixGenerationFOV) {
|
| 2060 |
+
updateStatus('Please fix Generation FOV first', cameraParams.length);
|
| 2061 |
+
return;
|
| 2062 |
+
}
|
| 2063 |
+
// Record the current camera pose
|
| 2064 |
+
const new_camera = camera.clone();
|
| 2065 |
+
new_camera.fov = guiOptions.FOV;
|
| 2066 |
+
new_camera.aspect = guiOptions.Resolution.split('x')[2] / guiOptions.Resolution.split('x')[1];
|
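// Resolution is an "NxHxW" string, so split('x')[2] / split('x')[1] is width / height.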
| 2067 |
+
new_camera.updateProjectionMatrix();
|
| 2068 |
+
|
| 2069 |
+
const cameraSplat = createCameraSplat(new_camera);
|
| 2070 |
+
cameraSplats.push(cameraSplat);
|
| 2071 |
+
cameraParams.push({
|
| 2072 |
+
position: new_camera.position.clone(),
|
| 2073 |
+
quaternion: new_camera.quaternion.clone(),
|
| 2074 |
+
fov: new_camera.fov,
|
| 2075 |
+
aspect: new_camera.aspect,
|
| 2076 |
+
});
|
| 2077 |
+
scene.add(cameraSplat);
|
| 2078 |
+
|
| 2079 |
+
updateStatus(`Camera ${cameraParams.length} recorded. Press Space for more or Generate!`, cameraParams.length);
|
| 2080 |
+
|
| 2081 |
+
console.log(new_camera.getFocalLength());
|
| 2082 |
+
}
|
| 2083 |
+
});
|
| 2084 |
+
|
| 2085 |
+
// =========================
|
| 2086 |
+
// Scene Initialization
|
| 2087 |
+
// =========================
|
| 2088 |
+
|
| 2089 |
+
// Initialize status
|
| 2090 |
+
updateStatus('FlashWorld initialized. Configure settings to begin.', 0);
|
| 2091 |
+
|
| 2092 |
+
// Add cube splat to the scene
|
| 2093 |
+
let instructionSplat = createCubeSplat(0.25, [1, 1, 1]);
|
| 2094 |
+
instructionSplat.position.set(0, 0, -1);
|
| 2095 |
+
scene.add(instructionSplat);
|
| 2096 |
+
console.log('Cube splat added to scene');
|
| 2097 |
+
|
| 2098 |
+
// Handle window resize
|
| 2099 |
+
window.addEventListener('resize', () => {
|
| 2100 |
+
console.log('Window resized, updating canvas...');
|
| 2101 |
+
// Update canvas size based on current resolution
|
| 2102 |
+
updateCanvasSize();
|
| 2103 |
+
});
|
| 2104 |
+
|
| 2105 |
+
// =========================
|
| 2106 |
+
// Animation Loop
|
| 2107 |
+
// =========================
|
| 2108 |
+
let lastTime = null;
|
| 2109 |
+
|
| 2110 |
+
renderer.setAnimationLoop(function animate(time) {
|
| 2111 |
+
const deltaTime = time - (lastTime || time);
|
| 2112 |
+
lastTime = time;
|
| 2113 |
+
|
| 2114 |
+
// Rotate the cube splat
|
| 2115 |
+
if (instructionSplat) {
|
| 2116 |
+
// instructionSplat.rotation.x += deltaTime / 4000; // rotate around the X axis
|
| 2117 |
+
instructionSplat.rotation.y += deltaTime / 5000; // rotate around the Y axis
|
| 2118 |
+
instructionSplat.rotation.z += deltaTime / 6000; // rotate around the Z axis
|
| 2119 |
+
}
|
| 2120 |
+
|
| 2121 |
+
// No active playback loop; scrubber directly sets camera
|
| 2122 |
+
|
| 2123 |
+
controls.update(camera);
|
| 2124 |
+
renderer.render(scene, camera);
|
| 2125 |
+
|
| 2126 |
+
});
|
| 2127 |
+
|
| 2128 |
+
</script>
|
| 2129 |
+
</body>
|
| 2130 |
+
</html>
|
models/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
| 1 |
+
from .autoencoder_kl_wan import AutoencoderKLWan
|
| 2 |
+
from .transformer_wan import WanTransformer3DModel
|
| 3 |
+
from .reconstruction_model import WANDecoderPixelAligned3DGSReconstructionModel
|
| 4 |
+
|
| 5 |
+
__all__ = ["AutoencoderKLWan", "WanTransformer3DModel", "WANDecoderPixelAligned3DGSReconstructionModel"]
|
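# Usage sketch: these re-exports let callers write
#   from models import AutoencoderKLWan, WanTransformer3DModel
# instead of importing from the individual submodules.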
models/autoencoder_kl_wan.py
ADDED
|
@@ -0,0 +1,1467 @@
| 1 |
+
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import List, Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
import torch.utils.checkpoint
|
| 21 |
+
|
| 22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from diffusers.loaders import FromOriginalModelMixin
|
| 24 |
+
from diffusers.utils import logging
|
| 25 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
| 26 |
+
from diffusers.models.activations import get_activation
|
| 27 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 28 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 29 |
+
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
|
| 30 |
+
|
| 31 |
+
import einops
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 35 |
+
|
| 36 |
+
CACHE_T = 2
|
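# Number of trailing frames each causal layer keeps as its feature cache, so
# long videos can be encoded/decoded chunk by chunk with consistent seams.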
| 37 |
+
|
| 38 |
+
class AvgDown3D(nn.Module):
|
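# Space/time-to-channel average pooling: each (factor_t, factor_s, factor_s)
# neighborhood is folded into the channel axis and averaged over groups of
# `group_size` channels; e.g. factor_t=2, factor_s=2 maps (B, C, 4, 64, 64)
# to (B, out_channels, 2, 32, 32) with no learned parameters.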
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
in_channels,
|
| 43 |
+
out_channels,
|
| 44 |
+
factor_t,
|
| 45 |
+
factor_s=1,
|
| 46 |
+
):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.in_channels = in_channels
|
| 49 |
+
self.out_channels = out_channels
|
| 50 |
+
self.factor_t = factor_t
|
| 51 |
+
self.factor_s = factor_s
|
| 52 |
+
self.factor = self.factor_t * self.factor_s * self.factor_s
|
| 53 |
+
|
| 54 |
+
assert in_channels * self.factor % out_channels == 0
|
| 55 |
+
self.group_size = in_channels * self.factor // out_channels
|
| 56 |
+
|
| 57 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 58 |
+
if not ((x.shape[2] == 1 and self.group_size >= self.factor) or self.factor_t == 1):
|
| 59 |
+
pad_t = (self.factor_t - x.shape[2] % self.factor_t)
|
| 60 |
+
pad = (0, 0, 0, 0, pad_t, 0)
|
| 61 |
+
x = F.pad(x, pad)
|
| 62 |
+
B, C, T, H, W = x.shape
|
| 63 |
+
x = x.view(
|
| 64 |
+
B,
|
| 65 |
+
C,
|
| 66 |
+
T // self.factor_t,
|
| 67 |
+
self.factor_t,
|
| 68 |
+
H // self.factor_s,
|
| 69 |
+
self.factor_s,
|
| 70 |
+
W // self.factor_s,
|
| 71 |
+
self.factor_s,
|
| 72 |
+
)
|
| 73 |
+
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
| 74 |
+
x = x.view(
|
| 75 |
+
B,
|
| 76 |
+
C * self.factor,
|
| 77 |
+
T // self.factor_t,
|
| 78 |
+
H // self.factor_s,
|
| 79 |
+
W // self.factor_s,
|
| 80 |
+
)
|
| 81 |
+
x = x.view(
|
| 82 |
+
B,
|
| 83 |
+
self.out_channels,
|
| 84 |
+
self.group_size,
|
| 85 |
+
T // self.factor_t,
|
| 86 |
+
H // self.factor_s,
|
| 87 |
+
W // self.factor_s,
|
| 88 |
+
)
|
| 89 |
+
x = x.mean(dim=2)
|
| 90 |
+
return x
|
| 91 |
+
else:
|
| 92 |
+
# print(1)
|
| 93 |
+
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
| 94 |
+
pad = (0, 0, 0, 0, pad_t, 0)
|
| 95 |
+
B, C, T, H, W = x.shape
|
| 96 |
+
x = x.view(
|
| 97 |
+
B,
|
| 98 |
+
C,
|
| 99 |
+
T,
|
| 100 |
+
1,
|
| 101 |
+
H // self.factor_s,
|
| 102 |
+
self.factor_s,
|
| 103 |
+
W // self.factor_s,
|
| 104 |
+
self.factor_s,
|
| 105 |
+
)
|
| 106 |
+
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
| 107 |
+
x = x.view(
|
| 108 |
+
B,
|
| 109 |
+
C * self.factor // self.factor_t,
|
| 110 |
+
T,
|
| 111 |
+
H // self.factor_s,
|
| 112 |
+
W // self.factor_s,
|
| 113 |
+
)
|
| 114 |
+
x = x.view(
|
| 115 |
+
B,
|
| 116 |
+
self.out_channels,
|
| 117 |
+
self.group_size // self.factor_t,
|
| 118 |
+
T,
|
| 119 |
+
H // self.factor_s,
|
| 120 |
+
W // self.factor_s,
|
| 121 |
+
)
|
| 122 |
+
# Since the pad values are zeros, the mean is in principle only correct after dividing by factor_t
|
| 123 |
+
x = x.mean(dim=2) / (pad_t + 1)
|
| 124 |
+
return x
|
| 125 |
+
|
| 126 |
+
class DupUp3D(nn.Module):
|
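# Inverse of AvgDown3D: channels are repeated `repeats` times and unfolded
# back into (factor_t, factor_s, factor_s) spatio-temporal neighborhoods,
# i.e. a duplication (nearest-neighbor) upsample with no learned parameters.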
| 127 |
+
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
in_channels: int,
|
| 131 |
+
out_channels: int,
|
| 132 |
+
factor_t,
|
| 133 |
+
factor_s=1,
|
| 134 |
+
):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.in_channels = in_channels
|
| 137 |
+
self.out_channels = out_channels
|
| 138 |
+
|
| 139 |
+
self.factor_t = factor_t
|
| 140 |
+
self.factor_s = factor_s
|
| 141 |
+
self.factor = self.factor_t * self.factor_s * self.factor_s
|
| 142 |
+
|
| 143 |
+
assert out_channels * self.factor % in_channels == 0
|
| 144 |
+
self.repeats = out_channels * self.factor // in_channels
|
| 145 |
+
|
| 146 |
+
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
| 147 |
+
if not (first_chunk and x.shape[2] == 1):
|
| 148 |
+
x = x.repeat_interleave(self.repeats, dim=1)
|
| 149 |
+
x = x.view(
|
| 150 |
+
x.size(0),
|
| 151 |
+
self.out_channels,
|
| 152 |
+
self.factor_t,
|
| 153 |
+
self.factor_s,
|
| 154 |
+
self.factor_s,
|
| 155 |
+
x.size(2),
|
| 156 |
+
x.size(3),
|
| 157 |
+
x.size(4),
|
| 158 |
+
)
|
| 159 |
+
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
| 160 |
+
x = x.view(
|
| 161 |
+
x.size(0),
|
| 162 |
+
self.out_channels,
|
| 163 |
+
x.size(2) * self.factor_t,
|
| 164 |
+
x.size(4) * self.factor_s,
|
| 165 |
+
x.size(6) * self.factor_s,
|
| 166 |
+
)
|
| 167 |
+
if first_chunk:
|
| 168 |
+
x = x[:, :, self.factor_t - 1:, :, :]
|
| 169 |
+
return x
|
| 170 |
+
else:
|
| 171 |
+
# print(1)
|
| 172 |
+
x = x.repeat_interleave(self.repeats // self.factor_t, dim=1)
|
| 173 |
+
x = x.view(
|
| 174 |
+
x.size(0),
|
| 175 |
+
self.out_channels,
|
| 176 |
+
1,
|
| 177 |
+
self.factor_s,
|
| 178 |
+
self.factor_s,
|
| 179 |
+
x.size(2),
|
| 180 |
+
x.size(3),
|
| 181 |
+
x.size(4),
|
| 182 |
+
)
|
| 183 |
+
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
| 184 |
+
x = x.view(
|
| 185 |
+
x.size(0),
|
| 186 |
+
self.out_channels,
|
| 187 |
+
x.size(2),
|
| 188 |
+
x.size(4) * self.factor_s,
|
| 189 |
+
x.size(6) * self.factor_s,
|
| 190 |
+
)
|
| 191 |
+
return x
|
| 192 |
+
|
| 193 |
+
class WanCausalConv3d(nn.Conv3d):
|
| 194 |
+
r"""
|
| 195 |
+
A custom 3D causal convolution layer with feature caching support.
|
| 196 |
+
|
| 197 |
+
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
|
| 198 |
+
caching for efficient inference.
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
in_channels (int): Number of channels in the input image
|
| 202 |
+
out_channels (int): Number of channels produced by the convolution
|
| 203 |
+
kernel_size (int or tuple): Size of the convolving kernel
|
| 204 |
+
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
| 205 |
+
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
def __init__(
|
| 209 |
+
self,
|
| 210 |
+
in_channels: int,
|
| 211 |
+
out_channels: int,
|
| 212 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
| 213 |
+
stride: Union[int, Tuple[int, int, int]] = 1,
|
| 214 |
+
padding: Union[int, Tuple[int, int, int]] = 0,
|
| 215 |
+
) -> None:
|
| 216 |
+
super().__init__(
|
| 217 |
+
in_channels=in_channels,
|
| 218 |
+
out_channels=out_channels,
|
| 219 |
+
kernel_size=kernel_size,
|
| 220 |
+
stride=stride,
|
| 221 |
+
padding=padding,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
# Set up causal padding
|
| 225 |
+
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
|
| 226 |
+
self.padding = (0, 0, 0)
|
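# F.pad's tuple is ordered (W_left, W_right, H_top, H_bottom, T_front, T_back),
# so all temporal padding (2 * padding[0]) goes in front: frame t never sees
# frames later than t, which is what makes the convolution causal.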
| 227 |
+
|
| 228 |
+
def forward(self, x, cache_x=None):
|
| 229 |
+
padding = list(self._padding)
|
| 230 |
+
if cache_x is not None and self._padding[4] > 0:
|
| 231 |
+
cache_x = cache_x.to(x.device)
|
| 232 |
+
x = torch.cat([cache_x, x], dim=2)
|
| 233 |
+
padding[4] -= cache_x.shape[2]
|
| 234 |
+
|
| 235 |
+
if any(padding):
|
| 236 |
+
x = F.pad(x, padding)
|
| 237 |
+
|
| 238 |
+
# print(x.shape)
|
| 239 |
+
return super().forward(x)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class WanRMS_norm(nn.Module):
|
| 243 |
+
r"""
|
| 244 |
+
A custom RMS normalization layer.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
dim (int): The number of dimensions to normalize over.
|
| 248 |
+
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
|
| 249 |
+
Default is True.
|
| 250 |
+
images (bool, optional): Whether the input represents image data. Default is True.
|
| 251 |
+
bias (bool, optional): Whether to include a learnable bias term. Default is False.
|
| 252 |
+
"""
|
| 253 |
+
|
| 254 |
+
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, weight: bool = True, bias: bool = False) -> None:
|
| 255 |
+
super().__init__()
|
| 256 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
| 257 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
| 258 |
+
|
| 259 |
+
self.channel_first = channel_first
|
| 260 |
+
self.scale = dim**0.5
|
| 261 |
+
self.gamma = nn.Parameter(torch.ones(shape)) if weight else 1.0
|
| 262 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
| 263 |
+
|
| 264 |
+
def forward(self, x):
|
| 265 |
+
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
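# F.normalize divides by the L2 norm, and multiplying by scale = sqrt(dim)
# turns that into division by the RMS: x / ||x||_2 * sqrt(dim) = x / rms(x).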
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class WanUpsample(nn.Upsample):
|
| 269 |
+
r"""
|
| 270 |
+
Perform upsampling while ensuring the output tensor has the same data type as the input.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
x (torch.Tensor): Input tensor to be upsampled.
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
torch.Tensor: Upsampled tensor with the same data type as the input.
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
def forward(self, x):
|
| 280 |
+
return super().forward(x.float()).type_as(x)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class WanResample(nn.Module):
|
| 284 |
+
r"""
|
| 285 |
+
A custom resampling module for 2D and 3D data.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
dim (int): The number of input/output channels.
|
| 289 |
+
mode (str): The resampling mode. Must be one of:
|
| 290 |
+
- 'none': No resampling (identity operation).
|
| 291 |
+
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
|
| 292 |
+
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
|
| 293 |
+
- 'downsample2d': 2D downsampling with zero-padding and convolution.
|
| 294 |
+
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
+
def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
|
| 298 |
+
super().__init__()
|
| 299 |
+
self.dim = dim
|
| 300 |
+
self.mode = mode
|
| 301 |
+
|
| 302 |
+
# default to dim // 2
|
| 303 |
+
if upsample_out_dim is None:
|
| 304 |
+
upsample_out_dim = dim // 2
|
| 305 |
+
|
| 306 |
+
# layers
|
| 307 |
+
if mode == "upsample2d":
|
| 308 |
+
self.resample = nn.Sequential(
|
| 309 |
+
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
|
| 310 |
+
)
|
| 311 |
+
elif mode == "upsample3d":
|
| 312 |
+
self.resample = nn.Sequential(
|
| 313 |
+
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
|
| 314 |
+
)
|
| 315 |
+
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
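# The time conv doubles the channel count; forward() reshapes those 2*dim
# channels into two output frames, so upsample3d doubles temporal resolution.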
| 316 |
+
|
| 317 |
+
elif mode == "downsample2d":
|
| 318 |
+
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 319 |
+
elif mode == "downsample3d":
|
| 320 |
+
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 321 |
+
self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
| 322 |
+
|
| 323 |
+
else:
|
| 324 |
+
self.resample = nn.Identity()
|
| 325 |
+
|
| 326 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 327 |
+
b, c, t, h, w = x.size()
|
| 328 |
+
if self.mode == "upsample3d":
|
| 329 |
+
if feat_cache is not None:
|
| 330 |
+
idx = feat_idx[0]
|
| 331 |
+
if feat_cache[idx] is None:
|
| 332 |
+
feat_cache[idx] = "Rep"
|
| 333 |
+
feat_idx[0] += 1
|
| 334 |
+
else:
|
| 335 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 336 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
|
| 337 |
+
# cache last frame of last two chunk
|
| 338 |
+
cache_x = torch.cat(
|
| 339 |
+
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
| 340 |
+
)
|
| 341 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
|
| 342 |
+
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
|
| 343 |
+
if feat_cache[idx] == "Rep":
|
| 344 |
+
x = self.time_conv(x)
|
| 345 |
+
else:
|
| 346 |
+
x = self.time_conv(x, feat_cache[idx])
|
| 347 |
+
feat_cache[idx] = cache_x
|
| 348 |
+
feat_idx[0] += 1
|
| 349 |
+
|
| 350 |
+
x = x.reshape(b, 2, c, t, h, w)
|
| 351 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
|
| 352 |
+
x = x.reshape(b, c, t * 2, h, w)
|
| 353 |
+
t = x.shape[2]
|
| 354 |
+
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
| 355 |
+
x = self.resample(x)
|
| 356 |
+
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
|
| 357 |
+
|
| 358 |
+
if self.mode == "downsample3d":
|
| 359 |
+
if feat_cache is not None:
|
| 360 |
+
idx = feat_idx[0]
|
| 361 |
+
if feat_cache[idx] is None:
|
| 362 |
+
feat_cache[idx] = x.clone()
|
| 363 |
+
feat_idx[0] += 1
|
| 364 |
+
else:
|
| 365 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
| 366 |
+
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
| 367 |
+
feat_cache[idx] = cache_x
|
| 368 |
+
feat_idx[0] += 1
|
| 369 |
+
return x
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class WanResidualBlock(nn.Module):
|
| 373 |
+
r"""
|
| 374 |
+
A custom residual block module.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
in_dim (int): Number of input channels.
|
| 378 |
+
out_dim (int): Number of output channels.
|
| 379 |
+
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
|
| 380 |
+
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
|
| 381 |
+
"""
|
| 382 |
+
|
| 383 |
+
def __init__(
|
| 384 |
+
self,
|
| 385 |
+
in_dim: int,
|
| 386 |
+
out_dim: int,
|
| 387 |
+
dropout: float = 0.0,
|
| 388 |
+
non_linearity: str = "silu",
|
| 389 |
+
) -> None:
|
| 390 |
+
super().__init__()
|
| 391 |
+
self.in_dim = in_dim
|
| 392 |
+
self.out_dim = out_dim
|
| 393 |
+
self.nonlinearity = get_activation(non_linearity)
|
| 394 |
+
|
| 395 |
+
# layers
|
| 396 |
+
self.norm1 = WanRMS_norm(in_dim, images=False)
|
| 397 |
+
self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
|
| 398 |
+
self.norm2 = WanRMS_norm(out_dim, images=False)
|
| 399 |
+
self.dropout = nn.Dropout(dropout)
|
| 400 |
+
self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
|
| 401 |
+
self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
|
| 402 |
+
|
| 403 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 404 |
+
# Apply shortcut connection
|
| 405 |
+
h = self.conv_shortcut(x)
|
| 406 |
+
|
| 407 |
+
# First normalization and activation
|
| 408 |
+
x = self.norm1(x)
|
| 409 |
+
x = self.nonlinearity(x)
|
| 410 |
+
|
| 411 |
+
if feat_cache is not None:
|
| 412 |
+
idx = feat_idx[0]
|
| 413 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 414 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 415 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 416 |
+
|
| 417 |
+
x = self.conv1(x, feat_cache[idx])
|
| 418 |
+
feat_cache[idx] = cache_x
|
| 419 |
+
feat_idx[0] += 1
|
| 420 |
+
else:
|
| 421 |
+
x = self.conv1(x)
|
| 422 |
+
|
| 423 |
+
# Second normalization and activation
|
| 424 |
+
x = self.norm2(x)
|
| 425 |
+
x = self.nonlinearity(x)
|
| 426 |
+
|
| 427 |
+
# Dropout
|
| 428 |
+
x = self.dropout(x)
|
| 429 |
+
|
| 430 |
+
if feat_cache is not None:
|
| 431 |
+
idx = feat_idx[0]
|
| 432 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 433 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 434 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 435 |
+
|
| 436 |
+
x = self.conv2(x, feat_cache[idx])
|
| 437 |
+
feat_cache[idx] = cache_x
|
| 438 |
+
feat_idx[0] += 1
|
| 439 |
+
else:
|
| 440 |
+
x = self.conv2(x)
|
| 441 |
+
|
| 442 |
+
# Add residual connection
|
| 443 |
+
return h.add_(x)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class WanAttentionBlock(nn.Module):
    r"""
    Causal self-attention with a single head.

    Args:
        dim (int): The number of channels in the input tensor.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = WanRMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        identity = x
        batch_size, channels, time, height, width = x.size()

        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
        x = self.norm(x)

        # compute query, key, value
        qkv = self.to_qkv(x)
        qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
        qkv = qkv.permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(q, k, v)

        x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)

        # output projection
        x = self.proj(x)

        # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
        x = x.view(batch_size, time, channels, height, width)
        x = x.permute(0, 2, 1, 3, 4)

        return identity.add_(x)

class WanMidBlock(nn.Module):
    """
    Middle block for WanVAE encoder and decoder.

    Args:
        dim (int): Number of input/output channels.
        dropout (float): Dropout rate.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
        super().__init__()
        self.dim = dim

        # Create the components
        resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
        attentions = []
        for _ in range(num_layers):
            attentions.append(WanAttentionBlock(dim))
            resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # First residual block
        x = self.resnets[0](x, feat_cache, feat_idx)

        # Process through attention and residual blocks
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                x = attn(x)

            x = resnet(x, feat_cache, feat_idx)

        return x

class WanResidualDownBlock(nn.Module):

    def __init__(self,
                 in_dim,
                 out_dim,
                 dropout,
                 num_res_blocks,
                 temperal_downsample=False,
                 down_flag=False):
        super().__init__()

        # Shortcut path with downsample
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Main path with residual blocks and downsample
        resnets = []
        for _ in range(num_res_blocks):
            resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim
        self.resnets = nn.ModuleList(resnets)

        # Add the final downsample block
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            self.downsampler = WanResample(out_dim, mode=mode)
        else:
            self.downsampler = None

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        x_copy = x.clone()
        for resnet in self.resnets:
            x = resnet(x, feat_cache, feat_idx)
        if self.downsampler is not None:
            x = self.downsampler(x, feat_cache, feat_idx)

        return self.avg_shortcut(x_copy).add_(x)

class WanEncoder3d(nn.Module):
    r"""
    A 3D encoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_downsample (list of bool): Whether to downsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(
        self,
        in_channels: int = 3,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
        non_linearity: str = "silu",
        is_residual: bool = False,  # the wan 2.2 vae uses a residual downblock
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.nonlinearity = get_activation(non_linearity)

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)

        # downsample blocks
        self.down_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if is_residual:
                self.down_blocks.append(
                    WanResidualDownBlock(
                        in_dim,
                        out_dim,
                        dropout,
                        num_res_blocks,
                        temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
                        down_flag=i != len(dim_mult) - 1,
                    )
                )
            else:
                for _ in range(num_res_blocks):
                    self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
                    if scale in attn_scales:
                        self.down_blocks.append(WanAttentionBlock(out_dim))
                    in_dim = out_dim

                # downsample block
                if i != len(dim_mult) - 1:
                    mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                    self.down_blocks.append(WanResample(out_dim, mode=mode))
                    scale /= 2.0

        # middle blocks
        self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)

        # output blocks
        self.norm_out = WanRMS_norm(out_dim, images=False)
        self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        assert x.shape[2] == 1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_in(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_in(x)

        ## downsamples
        for layer in self.down_blocks:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_out(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_out(x)
        return x

class WanResidualUpBlock(nn.Module):
    """
    A block that handles upsampling for the WanVAE decoder.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks
        dropout (float): Dropout rate
        temperal_upsample (bool): Whether to upsample on temporal dimension
        up_flag (bool): Whether to upsample or not
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        temperal_upsample: bool = False,
        up_flag: bool = False,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        if up_flag:
            self.avg_shortcut = DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2,
            )
        else:
            self.avg_shortcut = None

        # create residual blocks
        resnets = []
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
            current_dim = out_dim

        self.resnets = nn.ModuleList(resnets)

        # Add upsampling layer if needed
        if up_flag:
            upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
            self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
        else:
            self.upsampler = None

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor
            feat_cache (list, optional): Feature cache for causal convolutions
            feat_idx (list, optional): Feature index for cache management

        Returns:
            torch.Tensor: Output tensor
        """
        x_copy = x.clone()

        for resnet in self.resnets:
            if feat_cache is not None:
                x = resnet(x, feat_cache, feat_idx)
            else:
                x = resnet(x)

        if self.upsampler is not None:
            if feat_cache is not None:
                x = self.upsampler(x, feat_cache, feat_idx)
            else:
                x = self.upsampler(x)

        if self.avg_shortcut is not None:
            # print(x.shape, x_copy.shape, self.avg_shortcut(x_copy, first_chunk=first_chunk).shape)
            x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)

        return x

class WanUpBlock(nn.Module):
    """
    A block that handles upsampling for the WanVAE decoder.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks
        dropout (float): Dropout rate
        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        upsample_mode: Optional[str] = None,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Create layers list
        resnets = []
        # Add residual blocks and attention if needed
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
            current_dim = out_dim

        self.resnets = nn.ModuleList(resnets)

        # Add upsampling layer if needed
        self.upsamplers = None
        if upsample_mode is not None:
            self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor
            feat_cache (list, optional): Feature cache for causal convolutions
            feat_idx (list, optional): Feature index for cache management

        Returns:
            torch.Tensor: Output tensor
        """
        for resnet in self.resnets:
            if feat_cache is not None:
                x = resnet(x, feat_cache, feat_idx)
            else:
                x = resnet(x)

        if self.upsamplers is not None:
            if feat_cache is not None:
                x = self.upsamplers[0](x, feat_cache, feat_idx)
            else:
                x = self.upsamplers[0](x)
        return x

class WanDecoder3d(nn.Module):
    r"""
    A 3D decoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_upsample (list of bool): Whether to upsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
        non_linearity: str = "silu",
        out_channels: int = 3,
        is_residual: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        self.nonlinearity = get_activation(non_linearity)

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]

        # init block
        self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)

        # upsample blocks
        self.up_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i > 0 and not is_residual:
                # wan vae 2.1
                in_dim = in_dim // 2

            # determine if we need upsampling
            up_flag = i != len(dim_mult) - 1
            # determine the upsampling mode; if not upsampling, set to None
            upsample_mode = None
            if up_flag and temperal_upsample[i]:
                upsample_mode = "upsample3d"
            elif up_flag:
                upsample_mode = "upsample2d"
            # Create and add the upsampling block
            if is_residual:
                up_block = WanResidualUpBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    num_res_blocks=num_res_blocks,
                    dropout=dropout,
                    temperal_upsample=temperal_upsample[i] if up_flag else False,
                    up_flag=up_flag,
                    non_linearity=non_linearity,
                )
            else:
                up_block = WanUpBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    num_res_blocks=num_res_blocks,
                    dropout=dropout,
                    upsample_mode=upsample_mode,
                    non_linearity=non_linearity,
                )
            self.up_blocks.append(up_block)

        # output blocks
        self.norm_out = WanRMS_norm(out_dim, images=False)
        self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        assert x.shape[2] == 1
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_in(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_in(x)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## upsamples
        for up_block in self.up_blocks:
            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_out(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_out(x)
        return x

def patchify(x, patch_size):
    # YiYi TODO: refactor this
    from einops import rearrange

    if patch_size == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b c f (h q) (w r) -> b (c r q) f h w",
            q=patch_size,
            r=patch_size,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def unpatchify(x, patch_size):
    # YiYi TODO: refactor this
    from einops import rearrange

    if patch_size == 1:
        return x

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c r q) f h w -> b c f (h q) (w r)",
            q=patch_size,
            r=patch_size,
        )
    return x

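A quick sanity check of the semantics of the two helpers above (standalone snippet, assuming only `torch` and `einops` are installed): `patchify` folds each `p x p` spatial patch into the channel axis, and `unpatchify` inverts it exactly.

```python
# Round-trip check for patchify/unpatchify semantics (illustrative only).
import torch
from einops import rearrange

x = torch.randn(2, 3, 5, 16, 16)                       # (b, c, f, h, w)
p = 2
y = rearrange(x, "b c f (h q) (w r) -> b (c r q) f h w", q=p, r=p)
print(y.shape)                                          # torch.Size([2, 12, 5, 8, 8])
x2 = rearrange(y, "b (c r q) f h w -> b c f (h q) (w r)", q=p, r=p)
print(torch.equal(x, x2))                               # True
```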
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
    Introduced in [Wan 2.1].

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        base_dim: int = 96,
        decoder_base_dim: Optional[int] = None,
        z_dim: int = 16,
        dim_mult: Tuple[int] = [1, 2, 4, 4],
        num_res_blocks: int = 2,
        attn_scales: List[float] = [],
        temperal_downsample: List[bool] = [False, True, True],
        dropout: float = 0.0,
        latents_mean: List[float] = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
        ],
        latents_std: List[float] = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
        ],
        is_residual: bool = False,
        in_channels: int = 3,
        out_channels: int = 3,
        patch_size: Optional[int] = None,
        scale_factor_temporal: Optional[int] = 4,
        scale_factor_spatial: Optional[int] = 8,
        clip_output: bool = True,
    ) -> None:
        super().__init__()

        self.z_dim = z_dim
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        if decoder_base_dim is None:
            decoder_base_dim = base_dim

        self.encoder = WanEncoder3d(
            in_channels=in_channels, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult,
            num_res_blocks=num_res_blocks, attn_scales=attn_scales,
            temperal_downsample=temperal_downsample, dropout=dropout, is_residual=is_residual,
        )
        self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)

        self.decoder = WanDecoder3d(
            dim=decoder_base_dim, z_dim=z_dim, dim_mult=dim_mult, num_res_blocks=num_res_blocks,
            attn_scales=attn_scales, temperal_upsample=self.temperal_upsample, dropout=dropout,
            out_channels=out_channels, is_residual=is_residual,
        )

        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)

        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
        # to perform decoding of a single video latent at a time.
        self.use_slicing = False

        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
        # intermediate tiles together, the memory requirement can be lowered.
        self.use_tiling = False

        # The minimal tile height and width for spatial tiling to be used
        self.tile_sample_min_height = 256
        self.tile_sample_min_width = 256

        # The minimal distance between two spatial tiles
        self.tile_sample_stride_height = 192
        self.tile_sample_stride_width = 192

        # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
        self._cached_conv_counts = {
            "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
            if self.decoder is not None
            else 0,
            "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
            if self.encoder is not None
            else 0,
        }

    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
        tile_sample_min_width: Optional[int] = None,
        tile_sample_stride_height: Optional[float] = None,
        tile_sample_stride_width: Optional[float] = None,
    ) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_stride_height (`int`, *optional*):
                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
                no tiling artifacts produced across the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
                artifacts produced across the width dimension.
        """
        self.use_tiling = True
        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_tiling = False

    def enable_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def clear_cache(self):
        # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
        self._conv_num = self._cached_conv_counts["decoder"]
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = self._cached_conv_counts["encoder"]
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num

    def _encode(self, x: torch.Tensor):
        _, _, num_frame, height, width = x.shape

        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            return self.tiled_encode(x)

        self.clear_cache()
        if self.config.patch_size is not None:
            x = patchify(x, patch_size=self.config.patch_size)
        iter_ = 1 + (num_frame - 1) // 4
        self._enc_feat_map = None if iter_ == 1 else self._enc_feat_map
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                out = torch.cat([out, out_], 2)

        enc = self.quant_conv(out)
        self.clear_cache()
        return enc

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        r"""
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded videos. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)
        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

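A usage sketch for the encoding path (illustrative shapes; `vae` is assumed to be a constructed `AutoencoderKLWan` instance): with the default temporal factor of 4, a video with `1 + 4k` frames maps to `1 + k` latent frames, which is exactly the `iter_ = 1 + (num_frame - 1) // 4` chunking above.

```python
# Sketch of the chunked encode loop's frame arithmetic (vae is assumed
# to be an AutoencoderKLWan instance; shapes are illustrative).
import torch

video = torch.randn(1, 3, 17, 256, 256)   # 17 = 1 + 4 * 4 frames
with torch.no_grad():
    posterior = vae.encode(video).latent_dist
    z = posterior.mode()                   # deterministic latent
# Temporal: 1 + (17 - 1) // 4 = 5 latent frames; spatial: 256 / 8 = 32.
print(z.shape)                             # e.g. (1, z_dim, 5, 32, 32)
```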
    def _decode(self, z: torch.Tensor, return_dict: bool = True):
        _, _, num_frame, height, width = z.shape
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
            return self.tiled_decode(z, return_dict=return_dict)

        self.clear_cache()
        self._feat_map = None if num_frame == 1 else self._feat_map
        x = self.post_quant_conv(z)
        for i in range(num_frame):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True)
            else:
                out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)

        if self.config.clip_output:
            out = torch.clamp(out, min=-1.0, max=1.0)
        if self.config.patch_size is not None:
            out = unpatchify(out, patch_size=self.config.patch_size)
        self.clear_cache()
        if not return_dict:
            return (out,)

        return DecoderOutput(sample=out)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

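The two blend helpers above implement a plain linear cross-fade over the tile overlap region. A 1-D standalone illustration of the same weighting:

```python
# 1-D analogue of blend_v/blend_h: the left tile fades out, the right tile fades in.
import torch

blend = 4
a = torch.ones(1, 1, 1, 8, 1)       # previous tile (value 1.0)
b = torch.zeros(1, 1, 1, 8, 1)      # current tile (value 0.0)
for y in range(blend):
    b[:, :, :, y, :] = a[:, :, :, -blend + y, :] * (1 - y / blend) + b[:, :, :, y, :] * (y / blend)
print(b[0, 0, 0, :, 0])             # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0., 0., 0., 0.])
```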
    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        _, _, num_frames, height, width = x.shape
        latent_height = height // self.spatial_compression_ratio
        latent_width = width // self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = tile_latent_min_height - tile_latent_stride_height
        blend_width = tile_latent_min_width - tile_latent_stride_width

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, self.tile_sample_stride_height):
            row = []
            for j in range(0, width, self.tile_sample_stride_width):
                self.clear_cache()
                time = []
                frame_range = 1 + (num_frames - 1) // 4
                for k in range(frame_range):
                    self._enc_conv_idx = [0]
                    if k == 0:
                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                    else:
                        tile = x[
                            :,
                            :,
                            1 + 4 * (k - 1) : 1 + 4 * k,
                            i : i + self.tile_sample_min_height,
                            j : j + self.tile_sample_min_width,
                        ]
                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
                    tile = self.quant_conv(tile)
                    time.append(tile)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
        return enc

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        _, _, num_frames, height, width = z.shape
        sample_height = height * self.spatial_compression_ratio
        sample_width = width * self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                self.clear_cache()
                time = []
                for k in range(num_frames):
                    self._conv_idx = [0]
                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                    tile = self.post_quant_conv(tile)
                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
                    time.append(decoded)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """
        Args:
            sample (`torch.Tensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z, return_dict=return_dict)
        return dec
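Taken together, a minimal round trip through this autoencoder looks like the following sketch (not repository code; it assumes pretrained weights are available via `from_pretrained`, and the model id shown is a placeholder):

```python
# Hypothetical end-to-end usage of AutoencoderKLWan (placeholder model id).
import torch

vae = AutoencoderKLWan.from_pretrained("<wan-vae-checkpoint>", torch_dtype=torch.float32)
vae.enable_tiling()                           # optional: lowers peak memory for large frames

video = torch.randn(1, 3, 17, 480, 832)      # (b, c, f, h, w) in [-1, 1]
with torch.no_grad():
    z = vae.encode(video).latent_dist.mode()
    recon = vae.decode(z).sample              # clamped to [-1, 1] when clip_output=True
print(recon.shape)                            # matches the input video shape
```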
models/reconstruction_model.py
ADDED
@@ -0,0 +1,261 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
import copy
import numpy as np
import einops

from utils import zero_init, EMANorm, create_rays, quaternion_to_matrix

from .render import gaussian_render
from .autoencoder_kl_wan import WanCausalConv3d, WanRMS_norm, unpatchify


def inverse_sigmoid(x):
    # Inverse of the logistic sigmoid: log(x / (1 - x)).
    if type(x) == torch.Tensor:
        return torch.log(x / (1 - x))
    else:
        return math.log(x / (1 - x))


def inverse_softplus(x, beta=1):
    # Inverse of softplus: log(exp(beta * x) - 1) / beta.
    if type(x) == torch.Tensor:
        return (torch.exp(beta * x) - 1).log() / beta
    else:
        return math.log(math.exp(beta * x) - 1) / beta

class WANDecoderPixelAligned3DGSReconstructionModel(nn.Module):
    def __init__(self,
                 vae_model,
                 feat_dim,
                 # num_remove_decoder_up_blocks=0,
                 # num_points_per_pixel=4,
                 use_network_checkpointing=True,
                 use_render_checkpointing=True
                 ):
        super().__init__()

        self.decoder = copy.deepcopy(vae_model.decoder).requires_grad_(True)
        self.post_quant_conv = copy.deepcopy(vae_model.post_quant_conv).requires_grad_(True)

        self.extra_conv_in = WanCausalConv3d(feat_dim, self.decoder.conv_in.weight.shape[0], 3, padding=1)

        # Strip the causal temporal padding from the extra input conv and trim its
        # time kernel accordingly, since it operates on single-frame chunks.
        time_pad = self.extra_conv_in._padding[4]
        self.extra_conv_in.padding = (0, self.extra_conv_in._padding[2], self.extra_conv_in._padding[0])
        self.extra_conv_in._padding = (0, 0, 0, 0, 0, 0)
        self.extra_conv_in.weight = torch.nn.Parameter(self.extra_conv_in.weight[:, :, time_pad:].clone())

        with torch.no_grad():
            self.extra_conv_in.weight.data.zero_()
            self.extra_conv_in.bias.data.zero_()

        # remove one block
        # self.decoder.up_blocks = self.decoder.up_blocks[:-1]
        dims = [self.decoder.dim * u for u in [self.decoder.dim_mult[-1]] + self.decoder.dim_mult[::-1]]
        # self.decoder.up_blocks[-1].upsampler.mode = None
        # self.decoder.up_blocks[-1].upsampler.resample = nn.Identity()
        # self.decoder.up_blocks[-1].avg_shortcut = None

        self.decoder.norm_out = WanRMS_norm(dims[-1], images=False, bias=False)
        self.decoder.conv_out = nn.Identity()

        # add ema_norm for vae
        # for i_level in reversed(range(len(self.decoder.up_blocks))):
        #     if self.decoder.up_blocks[i_level].upsampler is not None:
        #         self.decoder.up_blocks[i_level].upsampler.resample = nn.Sequential(
        #             self.decoder.up_blocks[i_level].upsampler.resample,
        #         )

        self.patch_size = vae_model.config.patch_size
        # assert dims[-1] % 4 == 0
        self.gs_head = PixelAligned3DGS(dims[-1], num_points_per_pixel=2)

        del self.decoder.up_blocks[0].upsampler.time_conv
        del self.decoder.up_blocks[1].upsampler.time_conv

        self.decoder.conv_out = nn.Identity()

        self.network_checkpointing = use_network_checkpointing
        self.render_checkpointing = use_render_checkpointing

    def decode(self, feats, z):
        ## conv1
        x = self.decoder.conv_in(self.post_quant_conv(z)) + self.extra_conv_in(feats)

        ## middle
        if self.network_checkpointing and torch.is_grad_enabled():
            x = torch.utils.checkpoint.checkpoint(self.decoder.mid_block, x, None, [0], use_reentrant=False)
        else:
            x = self.decoder.mid_block(x, None, [0])

        ## upsamples
        for i, up_block in enumerate(self.decoder.up_blocks):
            if self.network_checkpointing and torch.is_grad_enabled():
                x = torch.utils.checkpoint.checkpoint(up_block, x, None, [0], True, use_reentrant=False)
            else:
                x = up_block(x, None, [0], first_chunk=True)

        # head
        x = self.decoder.norm_out(x)
        x = self.decoder.nonlinearity(x)
        x = self.decoder.conv_out(x)

        # if self.patch_size is not None:
        #     x = unpatchify(x, patch_size=self.patch_size)

        return x

    def forward(self, feats, z, cameras):
        x = self.decode(feats, z).squeeze(2)

        gaussian_params = self.gs_head(x, cameras.flatten(0, 1)).unflatten(0, (cameras.shape[0], cameras.shape[1]))

        return gaussian_params

    # def forward(self, images, cameras, scene_chunk_lens):
    #     x, z, feats = self.encode(images)
    #     return self.reconstruct(x, z, feats, cameras, scene_chunk_lens)

    @torch.amp.autocast(device_type='cuda', enabled=False)
    def render(self, gaussian_params, camerass, height, width, bg_mode='random'):
        camerass = camerass.to(torch.float32)

        test_c2ws = torch.eye(4, device=camerass.device)[None][None].repeat(camerass.shape[0], camerass.shape[1], 1, 1).float()
        test_c2ws[:, :, :3, :3] = quaternion_to_matrix(camerass[:, :, :4])
        test_c2ws[:, :, :3, 3] = camerass[:, :, 4:7]

        test_intr = torch.eye(3, device=camerass.device)[None, None].repeat(camerass.shape[0], camerass.shape[1], 1, 1).float()
        fx, fy, cx, cy = camerass[:, :, 7:11].split([1, 1, 1, 1], dim=-1)

        test_intr = torch.cat([fx * width, fy * height, cx * width, cy * height], dim=-1)

        return gaussian_render(gaussian_params, test_c2ws, test_intr, width, height, use_checkpoint=self.render_checkpointing, sh_degree=self.gs_head.sh_degree, bg_mode=bg_mode)

from torch.autograd import Function

class _trunc_exp(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-10, 10))

trunc_exp = _trunc_exp.apply

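`trunc_exp` keeps the forward pass an exact `exp` but clamps the saved input to `[-10, 10]` only inside the backward pass, which caps the gradient magnitude for extreme depth logits. A quick standalone check of that behaviour (assuming the `trunc_exp` defined above is in scope):

```python
# Gradient of trunc_exp is exp(clamp(x, -10, 10)), not exp(x).
import torch

x = torch.tensor([0.0, 20.0], requires_grad=True)
y = trunc_exp(x)          # forward is a plain exp
y.sum().backward()
print(y)                  # tensor([1.0000e+00, 4.8517e+08], ...)
print(x.grad)             # tensor([1.0000e+00, 2.2026e+04])  <- exp(10), clamped
```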
class PixelAligned3DGS(nn.Module):
    def __init__(
        self,
        embed_dim,
        sh_degree=2,
        use_mask=False,
        scale_range=(0, 16),  # related to pixel size
        num_points_per_pixel=1,
    ):
        super().__init__()

        self.sh_degree = sh_degree

        # sh, uv_offset, depth, opacity, scales, rotations
        # TODO: handle different sh_degree
        self.gaussian_channels = [3 * (self.sh_degree + 1) ** 2, 2, 1, 1, 3, 4, (1 if use_mask else 0)]

        self.gs_proj = nn.Conv2d(embed_dim, num_points_per_pixel * sum(self.gaussian_channels), 3, 1, 1)
        self.register_buffer("lrs_mul", torch.Tensor(
            [1] * 3 +  # sh 0
            [0.5] * 3 * ((self.sh_degree + 1) ** 2 - 1) +  # other sh
            [0.01] * 2 +  # uv_offset
            [1] * 1 +  # depth
            [1] * 1 +  # opacity
            [1] * 3 +  # scales
            [1] * 4 +  # rotations
            [0.1] * (1 if use_mask else 0)  # mask
        ).repeat(num_points_per_pixel), persistent=True)

        self.lrs_mul = self.lrs_mul / self.lrs_mul.max()

        self.use_mask = use_mask

        self.scale_range = scale_range

        with torch.no_grad():
            self.gs_proj.weight.data.zero_()
            self.gs_proj.bias = nn.Parameter(torch.Tensor(
                [0.0] * 3 * (self.sh_degree + 1) ** 2 +  # sh
                [0.0] * 2 +  # uv_offset
                [math.log(1)] * 1 +  # depth
                # [inverse_softplus(1)] * 1 +  # depth
                [inverse_sigmoid(0.1)] * 1 +  # opacity
                [inverse_sigmoid((1 - scale_range[0]) / (scale_range[1] - scale_range[0]))] * 3 +  # scales (default: 1, so the gaussian scale equals the pixel size)
                # [inverse_softplus(0.005)] * 3 +  # scales
                [1., 0, 0, 0] +  # rotations
                [inverse_sigmoid(0.9)] * (1 if use_mask else 0)  # mask (default: 0.9)
            ).repeat(num_points_per_pixel) / self.lrs_mul)

        self.num_points_per_pixel = num_points_per_pixel

    @torch.amp.autocast(device_type='cuda', enabled=False)
    def forward(self, x, cameras):
        x = x.to(torch.float32)
        cameras = cameras.to(torch.float32)

        BN, _, h, w = x.shape

        local_gaussian_params = F.conv2d(x, self.gs_proj.weight * self.lrs_mul[:, None, None, None], self.gs_proj.bias * self.lrs_mul, stride=1, padding=1).unflatten(1, (self.num_points_per_pixel, -1))
        # local_gaussian_params = F.conv2d(x, self.gs_proj.weight, self.gs_proj.bias, stride=1, padding=1).unflatten(1, (self.num_points_per_pixel, -1))

        # batch * n_frame, num_points_per_pixel, c, h, w -> batch * n_frame, num_points_per_pixel, h, w, c
        local_gaussian_params = local_gaussian_params.permute(0, 1, 3, 4, 2)

        features, uv_offset, depth, opacity, scales, rotations, mask = local_gaussian_params.split(self.gaussian_channels, dim=-1)

        rays_o, rays_d = create_rays(cameras[:, None].repeat(1, self.num_points_per_pixel, 1), uv_offset=uv_offset, h=h, w=w)

        depth = trunc_exp(depth)
        # depth = F.softplus(depth, beta=1)
        xyz = rays_o + depth * rays_d

        # features = features.unflatten(-1, (-1, 3))

        opacity = torch.sigmoid(opacity)
        if self.use_mask:
            if torch.is_grad_enabled():
                # straight-through estimator: hard mask in the forward pass, soft gradient in the backward pass
                mask = torch.sigmoid(mask)
                hard_mask = (mask > torch.rand_like(mask)).float()
                opacity = opacity * (mask + (hard_mask - mask).detach())
            else:
                mask = torch.sigmoid(mask)
                hard_mask = (mask > torch.rand_like(mask)).float()
                opacity = opacity * hard_mask

        fx, fy = cameras[:, 7:9].split([1, 1], dim=-1)
        fx, fy = fx / w, fy / h
        pixel_size = torch.sqrt(fx.pow(2) + fy.pow(2))[:, None, None, None] * depth
        scales = (torch.sigmoid(scales) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]) * pixel_size
        # scales = F.softplus(scales, beta=1)

        # It's not required to be normalized for gsplat rasterization?
        rotations = torch.nn.functional.normalize(rotations, dim=-1)

        gaussian_params = torch.cat([xyz, opacity, scales, rotations, features], dim=-1)

        return gaussian_params
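The head above places Gaussians by unprojecting every pixel along its camera ray, `xyz = rays_o + depth * rays_d`, with depth predicted in log space and the Gaussian scale tied to the pixel footprint at that depth. A compact sketch of the same unprojection with an explicit pinhole model (standalone and illustrative; in the repository this is handled by `create_rays` in `utils.py`):

```python
# Pinhole unprojection: pixel (u, v) + depth -> 3D point in the camera frame.
import torch

def unproject(depth, fx, fy, cx, cy):
    h, w = depth.shape
    v, u = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    dirs = torch.stack([(u - cx) / fx, (v - cy) / fy, torch.ones_like(depth)], dim=-1)
    return dirs * depth[..., None]          # (h, w, 3); rays_o = 0 in the camera frame

depth = torch.full((4, 4), 2.0)
pts = unproject(depth, fx=4.0, fy=4.0, cx=2.0, cy=2.0)
print(pts[0, 0])                            # tensor([-1., -1.,  2.])
```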
models/render.py
ADDED
@@ -0,0 +1,138 @@
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from gsplat import rasterization

# torch.backends.cuda.preferred_linalg_library(backend="magma")

"""
modified from https://github.com/arthurhero/Long-LRM/blob/main/model/llrm.py
"""
class GaussianRendererWithCheckpoint(torch.autograd.Function):
    @staticmethod
    def render(xyz, feature, scale, rotation, opacity, test_c2w, test_intr,
               W, H, sh_degree, near_plane, far_plane, backgrounds):
        test_w2c = test_c2w.float().inverse().unsqueeze(0)  # (1, 4, 4)
        test_intr_i = torch.zeros(3, 3).to(test_intr.device)
        test_intr_i[0, 0] = test_intr[0]
        test_intr_i[1, 1] = test_intr[1]
        test_intr_i[0, 2] = test_intr[2]
        test_intr_i[1, 2] = test_intr[3]
        test_intr_i[2, 2] = 1
        test_intr_i = test_intr_i.unsqueeze(0)  # (1, 3, 3)
        rendering, alpha, _ = rasterization(xyz, rotation, scale, opacity, feature,
                                            test_w2c, test_intr_i, W, H, sh_degree=sh_degree,
                                            near_plane=near_plane, far_plane=far_plane,
                                            render_mode="RGB+D",
                                            backgrounds=backgrounds[None],
                                            rasterize_mode='classic')  # (1, H, W, 4)
        # rendering[..., 3:] = rendering[..., 3:] + far_plane * (1 - alpha)
        return rendering

    @staticmethod
    def forward(ctx, xyz, feature, scale, rotation, opacity, test_c2ws, test_intr,
                W, H, sh_degree, near_plane, far_plane, backgrounds):
        ctx.save_for_backward(xyz, feature, scale, rotation, opacity, test_c2ws, test_intr, backgrounds)
        ctx.W = W
        ctx.H = H
        ctx.sh_degree = sh_degree
        ctx.near_plane = near_plane
        ctx.far_plane = far_plane
        with torch.no_grad():
            V, _ = test_intr.shape
            renderings = torch.zeros(V, H, W, 4).to(xyz.device)
            alphas = torch.rand(V, device=xyz.device)
            for iv in range(V):
                rendering = GaussianRendererWithCheckpoint.render(xyz, feature, scale, rotation, opacity,
                                                                  test_c2ws[iv], test_intr[iv], W, H, sh_degree, near_plane, far_plane, backgrounds[iv])
                renderings[iv:iv+1] = rendering

        renderings = renderings.requires_grad_()
        return renderings

    @staticmethod
    def backward(ctx, grad_output):
        xyz, feature, scale, rotation, opacity, test_c2ws, test_intr, backgrounds = ctx.saved_tensors
        xyz = xyz.detach().requires_grad_()
        feature = feature.detach().requires_grad_()
        scale = scale.detach().requires_grad_()
        rotation = rotation.detach().requires_grad_()
        opacity = opacity.detach().requires_grad_()
        W = ctx.W
        H = ctx.H
        sh_degree = ctx.sh_degree
        near_plane = ctx.near_plane
        far_plane = ctx.far_plane
        with torch.enable_grad():
            V, _ = test_intr.shape
            for iv in range(V):
                rendering = GaussianRendererWithCheckpoint.render(xyz, feature, scale, rotation, opacity,
                                                                  test_c2ws[iv], test_intr[iv], W, H, sh_degree, near_plane, far_plane, backgrounds[iv])
                rendering.backward(grad_output[iv:iv+1])

        return xyz.grad, feature.grad, scale.grad, rotation.grad, opacity.grad, None, None, None, None, None, None, None, None

def gaussian_render(gaussian_params, test_c2ws, test_intr, W, H, near_plane=0.01, far_plane=1000, use_checkpoint=False, sh_degree=0, bg_mode='random'):
|
| 81 |
+
|
| 82 |
+
if not torch.is_grad_enabled():
|
| 83 |
+
use_checkpoint = False
|
| 84 |
+
|
| 85 |
+
# opengl2colmap, see https://github.com/imlixinyang/Director3D/blob/main/modules/renderers/gaussians_renderer.py
|
| 86 |
+
test_c2ws[:, :, :3, 1:3] *= -1
|
| 87 |
+
|
| 88 |
+
device = test_intr.device
|
| 89 |
+
B, V, _ = test_intr.shape
|
| 90 |
+
|
| 91 |
+
renderings = []
|
| 92 |
+
|
| 93 |
+
for ib in range(B):
|
| 94 |
+
if bg_mode == 'random':
|
| 95 |
+
backgrounds = torch.rand(V, 3).to(device)
|
| 96 |
+
elif bg_mode == 'white':
|
| 97 |
+
backgrounds = torch.ones(V, 3).to(device)
|
| 98 |
+
elif bg_mode == 'black':
|
| 99 |
+
backgrounds = torch.zeros(V, 3).to(device)
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f"Invalid background mode: {bg_mode}")
|
| 102 |
+
|
| 103 |
+
xyz_i, opacity_i, scale_i, rotation_i, feature_i = gaussian_params[ib].float().split([3, 1, 3, 4, (sh_degree + 1)**2 * 3], dim=-1)
|
| 104 |
+
|
| 105 |
+
opacity_i = opacity_i.squeeze(-1)
|
| 106 |
+
feature_i = feature_i.reshape(-1, (sh_degree + 1)**2, 3)
|
| 107 |
+
|
| 108 |
+
if use_checkpoint:
|
| 109 |
+
|
| 110 |
+
renderings.append(GaussianRendererWithCheckpoint.apply(xyz_i, feature_i, scale_i, rotation_i, opacity_i, test_c2ws[ib], test_intr[ib], W, H, sh_degree, near_plane, far_plane, backgrounds))
|
| 111 |
+
|
| 112 |
+
else:
|
| 113 |
+
rendering = torch.zeros(V, H, W, 4).to(device)
|
| 114 |
+
for iv in range(V):
|
| 115 |
+
rendering[iv:iv+1] = GaussianRendererWithCheckpoint.render(xyz_i, feature_i, scale_i, rotation_i, opacity_i,
|
| 116 |
+
test_c2ws[ib][iv], test_intr[ib][iv], W, H, sh_degree, near_plane, far_plane, backgrounds[iv])
|
| 117 |
+
|
| 118 |
+
# test_w2c_i = test_c2ws[ib].float().inverse() # (V, 4, 4)
|
| 119 |
+
# test_intr_i = torch.zeros(V, 3, 3).to(device)
|
| 120 |
+
# test_intr_i[:, 0, 0] = test_intr[ib, :, 0]
|
| 121 |
+
# test_intr_i[:, 1, 1] = test_intr[ib, :, 1]
|
| 122 |
+
# test_intr_i[:, 0, 2] = test_intr[ib, :, 2]
|
| 123 |
+
# test_intr_i[:, 1, 2] = test_intr[ib, :, 3]
|
| 124 |
+
# test_intr_i[:, 2, 2] = 1
|
| 125 |
+
|
| 126 |
+
# # print(backgrounds.shape)
|
| 127 |
+
# rendering, _, _ = rasterization(xyz_i, rotation_i, scale_i, opacity_i, feature_i,
|
| 128 |
+
# test_w2c_i, test_intr_i, W, H, sh_degree=sh_degree,
|
| 129 |
+
# near_plane=near_plane, far_plane=far_plane,
|
| 130 |
+
# render_mode="RGB+D",
|
| 131 |
+
# backgrounds=backgrounds,
|
| 132 |
+
# rasterize_mode='classic') # (V, H, W, 3)
|
| 133 |
+
renderings.append(rendering)
|
| 134 |
+
|
| 135 |
+
renderings = torch.stack(renderings, dim=0).permute(0, 1, 4, 2, 3).contiguous() # (B, 3, V, H, W)
|
| 136 |
+
rgb = renderings[:, :, :3].mul_(2).add_(-1).clamp(-1, 1)
|
| 137 |
+
depth = renderings[:, :, 3:]
|
| 138 |
+
return rgb, depth
|
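
A shape-level sketch of driving `gaussian_render` (the `models.render` import path follows this repo's layout, but all concrete sizes and camera values are illustrative assumptions; gsplat requires a CUDA device):

```
import torch
import torch.nn.functional as F
from models.render import gaussian_render

B, V, N, H, W = 1, 2, 4096, 64, 64
# 14 channels per gaussian at sh_degree=0: xyz(3) + opacity(1) + scale(3) + rotation(4) + rgb(3)
xyz = torch.randn(B, N, 3)
opacity = torch.rand(B, N, 1)
scales = torch.rand(B, N, 3) * 0.05
rotations = F.normalize(torch.randn(B, N, 4), dim=-1)
rgb_sh = torch.rand(B, N, 3)
gaussian_params = torch.cat([xyz, opacity, scales, rotations, rgb_sh], dim=-1).cuda()

test_c2ws = torch.eye(4).repeat(B, V, 1, 1).cuda()                          # OpenGL-convention c2w
test_intr = torch.tensor([64.0, 64.0, 32.0, 32.0]).expand(B, V, 4).cuda()   # fx, fy, cx, cy
rgb, depth = gaussian_render(gaussian_params, test_c2ws, test_intr, W, H)
print(rgb.shape, depth.shape)   # (1, 2, 3, 64, 64), (1, 2, 1, 64, 64)
```

Note that `gaussian_render` flips the camera Y/Z axes of `test_c2ws` in place (OpenGL to COLMAP), so pass a clone if the poses are reused elsewhere.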
models/transformer_wan.py
ADDED
@@ -0,0 +1,601 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin

# SageAttention is optional; fall back to PyTorch SDPA when unavailable.
try:
    from sageattention import sageattn
except ImportError:
    sageattn = None

class FP32LayerNorm(nn.LayerNorm):
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            inputs,
            self.normalized_shape,
            self.weight if self.weight is not None else None,
            self.bias if self.bias is not None else None,
            self.eps,
        ).to(inputs.dtype)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class WanAttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # 512 is the context length of the text encoder, hardcoded for now
            image_context_length = encoder_hidden_states.shape[1] - 512
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query).to(hidden_states.dtype)
        if attn.norm_k is not None:
            key = attn.norm_k(key).to(hidden_states.dtype)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if rotary_emb is not None:

            def apply_rotary_emb(
                hidden_states: torch.Tensor,
                freqs_cos: torch.Tensor,
                freqs_sin: torch.Tensor,
            ):
                # Rotate consecutive channel pairs (0,1), (2,3), ... using the
                # interleaved cos/sin tables from WanRotaryPosEmbed.
                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
                x1, x2 = x[..., 0], x[..., 1]
                cos = freqs_cos[..., 0::2]
                sin = freqs_sin[..., 1::2]
                out = torch.empty_like(hidden_states)
                out[..., 0::2] = x1 * cos - x2 * sin
                out[..., 1::2] = x1 * sin + x2 * cos
                return out.type_as(hidden_states)

            query = apply_rotary_emb(query, *rotary_emb)
            key = apply_rotary_emb(key, *rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            if sageattn is not None:
                # Ensure kernels receive fp16/bf16 tensors under autocast
                if torch.is_autocast_enabled() and query.dtype not in (torch.float16, torch.bfloat16):
                    target_dtype = torch.bfloat16
                    query = query.to(target_dtype)
                    key_img = key_img.to(target_dtype)
                    value_img = value_img.to(target_dtype)
                hidden_states_img = sageattn(
                    query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
                )
            else:
                hidden_states_img = F.scaled_dot_product_attention(
                    query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
                )

            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        if sageattn is not None:
            # Ensure kernels receive fp16/bf16 tensors under autocast
            if torch.is_autocast_enabled() and query.dtype not in (torch.float16, torch.bfloat16):
                target_dtype = torch.bfloat16
                query = query.to(target_dtype)
                key = key.to(target_dtype)
                value = value.to(target_dtype)
            hidden_states = sageattn(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class WanImageEmbedding(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()

        self.norm1 = FP32LayerNorm(in_features)
        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
        self.norm2 = FP32LayerNorm(out_features)
        if pos_embed_seq_len is not None:
            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
        else:
            self.pos_embed = None

    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
        if self.pos_embed is not None:
            batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
            encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
            encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed

        hidden_states = self.norm1(encoder_hidden_states_image)
        hidden_states = self.ff(hidden_states)
        hidden_states = self.norm2(hidden_states)
        return hidden_states


class WanTimeTextImageEmbedding(nn.Module):
    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        super().__init__()

        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

        self.image_embedder = None
        if image_embed_dim is not None:
            self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        timestep_seq_len: Optional[int] = None,
    ):
        timestep = self.timesteps_proj(timestep)
        if timestep_seq_len is not None:
            timestep = timestep.unflatten(0, (1, timestep_seq_len))

        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))

        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        if encoder_hidden_states_image is not None:
            encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image


class WanRotaryPosEmbed(nn.Module):
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64

        freqs_cos = []
        freqs_sin = []

        for dim in [t_dim, h_dim, w_dim]:
            freq_cos, freq_sin = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=True,
                repeat_interleave_real=True,
                freqs_dtype=freqs_dtype,
            )
            freqs_cos.append(freq_cos)
            freqs_sin.append(freq_sin)

        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        split_sizes = [
            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
            self.attention_head_dim // 3,
            self.attention_head_dim // 3,
        ]

        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)

        return freqs_cos, freqs_sin


@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            processor=WanAttnProcessor2_0(),
        )

        # 2. Cross-attention
        self.attn2 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            added_kv_proj_dim=added_kv_proj_dim,
            added_proj_bias=True,
            processor=WanAttnProcessor2_0(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
    ) -> torch.Tensor:
        if temb.ndim == 4:
            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table.unsqueeze(0) + temb
            ).chunk(6, dim=2)
            # batch_size, seq_len, 1, inner_dim
            shift_msa = shift_msa.squeeze(2)
            scale_msa = scale_msa.squeeze(2)
            gate_msa = gate_msa.squeeze(2)
            c_shift_msa = c_shift_msa.squeeze(2)
            c_scale_msa = c_scale_msa.squeeze(2)
            c_gate_msa = c_gate_msa.squeeze(2)
        else:
            # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table + temb
            ).chunk(6, dim=1)

        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states).mul_(1 + scale_msa).add_(shift_msa))
        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
        hidden_states += attn_output * gate_msa

        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states)
        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
        hidden_states += attn_output

        # 3. Feed-forward
        norm_hidden_states = (self.norm3(hidden_states).mul_(1 + c_scale_msa).add_(c_shift_msa))
        ff_output = self.ffn(norm_hidden_states)
        hidden_states += ff_output.mul_(c_gate_msa)

        return hidden_states


class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    r"""
    A Transformer model for video-like data used in the Wan model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `40`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
            Enable query/key normalization.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["WanTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # Run the transformer blocks in bfloat16 regardless of the incoming dtype.
        encoder_hidden_states = encoder_hidden_states.to(torch.bfloat16)
        timestep_proj = timestep_proj.to(torch.bfloat16)
        rotary_emb = [rotary_emb[0].to(torch.bfloat16), rotary_emb[1].to(torch.bfloat16)]
        hidden_states = hidden_states.to(torch.bfloat16)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

        # 5. Output norm, projection & unpatchify
        if temb.ndim == 3:
            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            # batch_size, inner_dim
            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
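
The rotary tables built by `WanRotaryPosEmbed` use `repeat_interleave_real=True`, so each frequency is stored twice, and `apply_rotary_emb` inside `WanAttnProcessor2_0` rotates consecutive channel pairs `(0,1), (2,3), ...` by those angles. A self-contained sketch of that pairing on a single 4-channel vector (illustrative names, not repo code):

```
import torch

def rotate_pairs(x, freqs_cos, freqs_sin):
    # Same even/odd pairing as apply_rotary_emb in WanAttnProcessor2_0.
    x1, x2 = x.view(*x.shape[:-1], -1, 2).unbind(-1)
    cos, sin = freqs_cos[..., 0::2], freqs_sin[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

theta = torch.tensor([0.0, torch.pi / 2])    # one rotation angle per channel pair
cos = theta.cos().repeat_interleave(2)       # interleaved table, as in the buffers
sin = theta.sin().repeat_interleave(2)
x = torch.tensor([1.0, 0.0, 1.0, 0.0])
print(rotate_pairs(x, cos, sin))             # ~[1, 0, 0, 1]: second pair rotated by 90 degrees
```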
quant.py
ADDED
@@ -0,0 +1,195 @@
import gc
from typing import Tuple
import copy
import torch
import tqdm


def cleanup_memory():
    gc.collect()
    torch.cuda.empty_cache()


def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
    """Quantize a tensor using a dynamic per-tensor scaling factor.
    Args:
        tensor: The input tensor.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Calculate the scale as dtype max divided by absmax.
    # Since .abs() creates a new tensor, we use aminmax to get
    # the min and max first and then calculate the absmax.
    if tensor.numel() == 0:
        # Deal with empty tensors (triggered by empty MoE experts)
        min_val, max_val = (
            torch.tensor(-16.0, dtype=tensor.dtype),
            torch.tensor(16.0, dtype=tensor.dtype),
        )
    else:
        min_val, max_val = tensor.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs())
    scale = finfo.max / amax.clamp(min=1e-12)
    # Scale and clamp the tensor to bring it into
    # the representable range of the float8 data type
    # (as the default cast is unsaturated).
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(torch.float8_e4m3fn)
    scale = scale.float().reciprocal()
    return qweight, scale


def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
    """Quantizes a floating-point tensor to FP8 (E4M3 format) using static scaling.

    Performs uniform quantization of the input tensor by:
    1. Scaling the tensor values using the provided inverse scale factor
    2. Clamping values to the representable range of FP8 E4M3 format
    3. Converting to FP8 data type

    Args:
        tensor (torch.Tensor): Input tensor to be quantized (any floating-point dtype)
        inv_scale (float): Inverse of the quantization scale factor (1/scale)
            (Must be pre-calculated based on tensor statistics)

    Returns:
        torch.Tensor: Quantized tensor in torch.float8_e4m3fn format

    Note:
        - Uses the E4M3 format (4 exponent bits, 3 mantissa bits, no infinity/nan)
        - This is a static quantization (scale factor must be pre-determined)
        - For dynamic quantization, see per_tensor_quantize()
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype, native_fp8_support=False):
    """Performs FP8 GEMM (General Matrix Multiplication) operation with optional native hardware support.
    Args:
        A (torch.Tensor): Input tensor A (FP8 or other dtype)
        A_scale (torch.Tensor/float): Scale factor for tensor A
        B (torch.Tensor): Input tensor B (FP8 or other dtype)
        B_scale (torch.Tensor/float): Scale factor for tensor B
        bias (torch.Tensor/None): Optional bias tensor
        out_dtype (torch.dtype): Output data type
        native_fp8_support (bool): Whether to use hardware-accelerated FP8 operations

    Returns:
        torch.Tensor: Result of GEMM operation
    """
    if A.numel() == 0:
        # Deal with empty tensors (triggered by empty MoE experts)
        return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)

    if native_fp8_support:
        # torch._scaled_mm expects 2-D inputs, so flatten a leading batch dim.
        need_reshape = A.dim() == 3
        if need_reshape:
            batch_size = A.shape[0]
            A_input = A.reshape(-1, A.shape[-1])
        else:
            batch_size = None
            A_input = A
        output = torch._scaled_mm(
            A_input,
            B.t(),
            out_dtype=out_dtype,
            scale_a=A_scale,
            scale_b=B_scale,
            bias=bias,
        )
        if need_reshape:
            output = output.reshape(
                batch_size, output.shape[0] // batch_size, output.shape[1]
            )
    else:
        # Fallback: dequantize both operands and use a regular linear.
        output = torch.nn.functional.linear(
            A.to(out_dtype) * A_scale,
            B.to(out_dtype) * B_scale.to(out_dtype),
            bias=bias,
        )

    return output

def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1:]
        parent = model.get_submodule(parent_name)
    else:
        parent_name = ""
        parent = model
        child_name = name
    setattr(parent, child_name, new_module)


# Class responsible for quantizing weights
class FP8DynamicLinear(torch.nn.Module):
    def __init__(
        self,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        bias: torch.nn.Parameter,
        native_fp8_support: bool = False,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        self.bias = bias
        self.native_fp8_support = native_fp8_support
        self.dtype = dtype

    # @torch.compile
    def forward(self, x):
        if x.dtype != self.dtype:
            x = x.to(self.dtype)
        qinput, x_scale = per_tensor_quantize(x)
        output = fp8_gemm(
            A=qinput,
            A_scale=x_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=self.bias,
            out_dtype=x.dtype,
            native_fp8_support=self.native_fp8_support,
        )
        return output


def FluxFp8GeMMProcessor(model: torch.nn.Module):
    """Processes a PyTorch model to convert eligible Linear layers to FP8 precision.

    This function performs the following operations:
    1. Checks for native FP8 support on the current GPU
    2. Identifies target Linear layers in transformer blocks
    3. Quantizes weights to FP8 format
    4. Replaces original Linear layers with FP8DynamicLinear versions
    5. Performs memory cleanup

    Args:
        model (torch.nn.Module): The neural network model to be processed.
            Should contain transformer blocks with Linear layers.
    """
    native_fp8_support = (
        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
    )
    named_modules = list(model.named_modules())
    for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights to fp8"):
        if isinstance(linear, torch.nn.Linear) and "blocks" in name:
            quant_weight, weight_scale = per_tensor_quantize(linear.weight)
            bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
            quant_linear = FP8DynamicLinear(
                weight=quant_weight,
                weight_scale=weight_scale,
                bias=bias,
                native_fp8_support=native_fp8_support,
                dtype=linear.weight.dtype
            )
            replace_module(model, name, quant_linear)
            del linear.weight
            del linear.bias
            del linear
    cleanup_memory()
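
A quick round trip through `per_tensor_quantize` makes the convention concrete: the returned scale is the *inverse* scale, so dequantization is a single multiply (a sketch; assumes a PyTorch build with `float8_e4m3fn`, as pinned in `requirements.txt`):

```
import torch
from quant import per_tensor_quantize

w = torch.randn(4, 8)
qweight, inv_scale = per_tensor_quantize(w)
w_hat = qweight.to(torch.float32) * inv_scale   # dequantize
print(qweight.dtype)                            # torch.float8_e4m3fn
print(inv_scale)                                # abs-max of w divided by 448 (E4M3 max)
print((w - w_hat).abs().max())                  # small per-tensor rounding error
```

`FluxFp8GeMMProcessor` applies exactly this per-tensor scheme to every `nn.Linear` whose qualified name contains `"blocks"`, falling back to the dequantize-then-`linear` path on GPUs without native FP8 matmul support.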
requirements.txt
ADDED
@@ -0,0 +1,19 @@
torch==2.6.0
torchvision==0.21.0
triton==3.2.0
transformers==4.57.0
omegaconf==2.3.0
ninja==1.13.0
numpy==2.2.6
einops==0.8.1
moviepy==1.0.3
opencv-python==4.12.0.88
av==15.1.0
plyfile==1.1.2
ftfy==6.3.1
flask==3.1.2
gradio==5.49.1
gsplat==1.5.2
accelerate==1.10.1
git+https://github.com/huggingface/diffusers.git@447e8322f76efea55d4769cd67c372edbf0715b8
git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712
utils.py
ADDED
@@ -0,0 +1,531 @@
from io import BytesIO
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
from plyfile import PlyData, PlyElement

import copy

class EmbedContainer(nn.Module):
    def __init__(self, tensor):
        super().__init__()
        self.tensor = nn.Parameter(tensor)

    def forward(self):
        return self.tensor

@torch.no_grad
def zero_init(module):
    if type(module) is torch.nn.Conv2d or type(module) is torch.nn.Linear:
        module.weight.zero_()
        module.bias.zero_()
    return module

def import_str(string):
    # From https://github.com/CompVis/taming-transformers
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module, package=None), cls)

"""
from https://github.com/Kai-46/minFM/blob/main/utils/ema.py
Exponential Moving Average (EMA) utilities for PyTorch models.

This module provides utilities for maintaining and updating EMA models,
which are commonly used to improve model stability and generalization
in training deep neural networks. It supports both regular tensors and
DTensors (from FSDP-wrapped models).
"""
class EMA_FSDP:
    def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {}
        self._init_shadow(fsdp_module)

    @torch.no_grad()
    def _init_shadow(self, fsdp_module):
        # Check whether the module is FSDP-wrapped.
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        if isinstance(fsdp_module, FSDP):
            with FSDP.summon_full_params(fsdp_module, writeback=False):
                for n, p in fsdp_module.module.named_parameters():
                    self.shadow[n] = p.detach().clone().float().cpu()
        else:
            for n, p in fsdp_module.named_parameters():
                self.shadow[n] = p.detach().clone().float().cpu()

    @torch.no_grad()
    def update(self, fsdp_module):
        d = self.decay
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        if isinstance(fsdp_module, FSDP):
            with FSDP.summon_full_params(fsdp_module, writeback=False):
                for n, p in fsdp_module.module.named_parameters():
                    self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
        else:
            for n, p in fsdp_module.named_parameters():
                self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)

    # Optional helpers ---------------------------------------------------
    def state_dict(self):
        return self.shadow  # picklable

    def load_state_dict(self, sd):
        self.shadow = {k: v.clone() for k, v in sd.items()}

    def copy_to(self, fsdp_module):
        # load EMA weights into an (unwrapped) copy of the generator
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        with FSDP.summon_full_params(fsdp_module, writeback=True):
            for n, p in fsdp_module.module.named_parameters():
                if n in self.shadow:
                    p.data.copy_(self.shadow[n].to(device=p.device, dtype=p.dtype))

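# --- Illustration (added note, not part of the original file) ---
# A minimal sketch of using EMA_FSDP with a plain (non-FSDP) module;
# `model` and the training loop are hypothetical stand-ins:
#
#   model = torch.nn.Linear(8, 8)
#   ema = EMA_FSDP(model, decay=0.999)
#   for step in range(num_steps):
#       ...  # optimizer step updating `model`
#       ema.update(model)   # shadow <- decay * shadow + (1 - decay) * params
#   torch.save(ema.state_dict(), "ema.ckpt")
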
def create_raymaps(cameras, h, w):
    rays_o, rays_d = create_rays(cameras, h, w)
    raymaps = torch.cat([rays_d, rays_o - (rays_o * rays_d).sum(dim=-1, keepdim=True) * rays_d], dim=-1)
    return raymaps

# def create_raymaps(cameras, h, w):
#     rays_o, rays_d = create_rays(cameras, h, w)
#     raymaps = torch.cat([rays_d, torch.cross(rays_d, rays_o, dim=-1)], dim=-1)
#     return raymaps

class EMANorm(nn.Module):
    def __init__(self, beta):
        super().__init__()
        self.register_buffer('magnitude_ema', torch.ones([]))
        self.beta = beta

    def forward(self, x):
        if self.training:
            magnitude_cur = x.detach().to(torch.float32).square().mean()
            self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema.to(torch.float32), self.beta))
        input_gain = self.magnitude_ema.rsqrt()
        x = x.mul(input_gain)
        return x

class TimestepEmbedding(nn.Module):
    def __init__(self, dim, max_period=10000, time_factor: float = 1000.0, zero_weight: bool = True):
        super().__init__()
        self.max_period = max_period
        self.time_factor = time_factor
        self.dim = dim
        if zero_weight:
            self.weight = nn.Parameter(torch.zeros(dim))
        else:
            self.weight = None

    def forward(self, t):
        if self.weight is None:
            return timestep_embedding(t, self.dim, self.max_period, self.time_factor)
        else:
            return timestep_embedding(t, self.dim, self.max_period, self.time_factor) * self.weight.unsqueeze(0)

@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding

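# --- Illustration (added note, not part of the original file) ---
# For t = 0 every sinusoidal argument is zero, so the embedding is all
# cosine ones followed by all sine zeros:
#
#   emb = timestep_embedding(torch.tensor([0.0]), dim=8)
#   # emb -> tensor([[1., 1., 1., 1., 0., 0., 0., 0.]])
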
| 150 |
+
+ def quaternion_to_matrix(quaternions):
+     """
+     Convert rotations given as quaternions to rotation matrices.
+     Args:
+         quaternions: quaternions with real part first,
+             as tensor of shape (..., 4).
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+     r, i, j, k = torch.unbind(quaternions, -1)
+     two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+     o = torch.stack(
+         (
+             1 - two_s * (j * j + k * k),
+             two_s * (i * j - k * r),
+             two_s * (i * k + j * r),
+             two_s * (i * j + k * r),
+             1 - two_s * (i * i + k * k),
+             two_s * (j * k - i * r),
+             two_s * (i * k - j * r),
+             two_s * (j * k + i * r),
+             1 - two_s * (i * i + j * j),
+         ),
+         -1,
+     )
+     return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+ # from https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+     """
+     Convert a unit quaternion to a standard form: one in which the real
+     part is non-negative.
+
+     Args:
+         quaternions: Quaternions with real part first,
+             as tensor of shape (..., 4).
+
+     Returns:
+         Standardized quaternions as tensor of shape (..., 4).
+     """
+     return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+     """
+     Returns torch.sqrt(torch.max(0, x))
+     but with a zero subgradient where x is 0.
+     """
+     ret = torch.zeros_like(x)
+     positive_mask = x > 0
+     if torch.is_grad_enabled():
+         ret[positive_mask] = torch.sqrt(x[positive_mask])
+     else:
+         ret = torch.where(positive_mask, torch.sqrt(x), ret)
+     return ret
+
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+     """
+     Convert rotations given as rotation matrices to quaternions.
+
+     Args:
+         matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+     Returns:
+         Quaternions with real part first, as tensor of shape (..., 4).
+     """
+     if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+         raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+     batch_dim = matrix.shape[:-2]
+     m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+         matrix.reshape(batch_dim + (9,)), dim=-1
+     )
+
+     q_abs = _sqrt_positive_part(
+         torch.stack(
+             [
+                 1.0 + m00 + m11 + m22,
+                 1.0 + m00 - m11 - m22,
+                 1.0 - m00 + m11 - m22,
+                 1.0 - m00 - m11 + m22,
+             ],
+             dim=-1,
+         )
+     )
+
+     # we produce the desired quaternion multiplied by each of r, i, j, k
+     quat_by_rijk = torch.stack(
+         [
+             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+             torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+             torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+             torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+             torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+         ],
+         dim=-2,
+     )
+
+     # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+     # the candidate won't be picked.
+     flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+     quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+     # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+     # forall i; we pick the best-conditioned one (with the largest denominator)
+     indices = q_abs.argmax(dim=-1, keepdim=True)
+     expand_dims = list(batch_dim) + [1, 4]
+     gather_indices = indices.unsqueeze(-1).expand(expand_dims)
+     out = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2)
+     return standardize_quaternion(out)
+
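`matrix_to_quaternion` computes four algebraically equivalent quaternion candidates and picks, per element, the best-conditioned one (the one whose `q_abs` denominator is largest) before standardizing the sign. A hedged round-trip sanity check using the two conversion functions above:

```
import torch
import torch.nn.functional as F

q = F.normalize(torch.randn(8, 4), dim=-1)
q_back = matrix_to_quaternion(quaternion_to_matrix(q))
# q and -q encode the same rotation, so compare up to sign
assert torch.allclose((q * q_back).sum(-1).abs(), torch.ones(8), atol=1e-5)
```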
+ @torch.amp.autocast(device_type="cuda", enabled=False)
+ def normalize_cameras(cameras, return_meta=False, ref_w2c=None, T_norm=None, n_frame=None):
+     B, N = cameras.shape[:2]
+
+     c2ws = torch.zeros(B, N, 3, 4, device=cameras.device)
+     c2ws[..., :3, :3] = quaternion_to_matrix(cameras[..., 0:4])
+     c2ws[..., :, 3] = cameras[..., 4:7]
+
+     _c2ws = c2ws
+
+     ref_w2c = torch.inverse(matrix_to_square(_c2ws[:, :1])) if ref_w2c is None else ref_w2c
+     _c2ws = (ref_w2c.repeat(1, N, 1, 1) @ matrix_to_square(_c2ws))[..., :3, :]
+
+     if n_frame is not None:
+         T_norm = _c2ws[..., :n_frame, :3, 3].norm(dim=-1).max(dim=1)[0][..., None, None] if T_norm is None else T_norm
+     else:
+         T_norm = _c2ws[..., :3, 3].norm(dim=-1).max(dim=1)[0][..., None, None] if T_norm is None else T_norm
+
+     _c2ws[..., :3, 3] = _c2ws[..., :3, 3] / (T_norm + 1e-2)
+
+     R = matrix_to_quaternion(_c2ws[..., :3, :3])
+     T = _c2ws[..., :3, 3]
+     cameras = torch.cat([R.float(), T.float(), cameras[..., 7:]], dim=-1)
+
+     if return_meta:
+         return cameras, ref_w2c, T_norm
+     else:
+         return cameras
+
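Cameras are packed as 11 numbers per view: a wxyz quaternion (4), a translation (3), and intrinsics fx, fy, cx, cy pre-divided by the image size (4). `normalize_cameras` re-expresses every pose relative to the first camera and divides translations by the largest camera distance (optionally measured over only the first `n_frame` views), so trajectories land in a roughly unit-scale canonical frame; returning `ref_w2c` and `T_norm` lets callers undo the normalization later, as `export_ply_for_gaussians` below does with `T_norm`. A hedged sketch with a hypothetical camera batch:

```
import torch

# hypothetical batch: B=1 trajectory of N=4 cameras, 11 numbers each:
# [qw, qx, qy, qz, tx, ty, tz, fx, fy, cx, cy], intrinsics already divided by W/H
cameras = torch.zeros(1, 4, 11)
cameras[..., 0] = 1.0                              # identity rotations
cameras[..., 4] = torch.linspace(0.0, 4.0, 4)      # slide 4 units along +x
cameras[..., 7:] = torch.tensor([0.8, 0.8, 0.5, 0.5])

normed, ref_w2c, T_norm = normalize_cameras(cameras, return_meta=True)
print(T_norm.flatten())            # tensor([4.]) -- the farthest camera distance
print(normed[0, -1, 4:7])          # last translation now ~ [1, 0, 0]
```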
+ def create_rays(cameras, h, w, uv_offset=None):
+     prefix_shape = cameras.shape[:-1]
+     cameras = cameras.flatten(0, -2)
+     device = cameras.device
+     N = cameras.shape[0]
+
+     c2w = torch.eye(4, device=device)[None].repeat(N, 1, 1)
+     c2w[:, :3, :3] = quaternion_to_matrix(cameras[:, :4])
+     c2w[:, :3, 3] = cameras[:, 4:7]
+
+     # fx, fy, cx, cy are expected to be pre-divided by the original W and H
+     fx, fy, cx, cy = cameras[:, 7:].chunk(4, -1)
+
+     fx, cx = fx * w, cx * w
+     fy, cy = fy * h, cy * h
+
+     inds = torch.arange(0, h * w, device=device).expand(N, h * w)
+
+     i = inds % w + 0.5   # pixel-center x coordinate
+     j = torch.div(inds, w, rounding_mode='floor') + 0.5   # pixel-center y coordinate
+
+     u = i / cx + (uv_offset[..., 0].reshape(N, h * w) if uv_offset is not None else 0)
+     v = j / cy + (uv_offset[..., 1].reshape(N, h * w) if uv_offset is not None else 0)
+
+     zs = - torch.ones_like(i)   # camera looks down -z
+     xs = - (u - 1) * cx / fx * zs
+     ys = (v - 1) * cy / fy * zs
+     directions = torch.stack((xs, ys, zs), dim=-1)
+
+     rays_d = F.normalize(directions @ c2w[:, :3, :3].transpose(-1, -2), dim=-1)
+
+     rays_o = c2w[..., :3, 3]   # [N, 3] camera centers
+     rays_o = rays_o[..., None, :].expand_as(rays_d)
+
+     rays_o = rays_o.reshape(*prefix_shape, h, w, 3)
+     rays_d = rays_d.reshape(*prefix_shape, h, w, 3)
+
+     return rays_o, rays_d
+
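`create_rays` unprojects pixel centers with an OpenGL-style convention (camera looking down -z) and rotates the directions into world space; `rays_o` is simply each camera center broadcast per pixel. A quick shape check, reusing the hypothetical `cameras` batch from the sketch above:

```
rays_o, rays_d = create_rays(cameras, h=30, w=40)
print(rays_o.shape, rays_d.shape)   # torch.Size([1, 4, 30, 40, 3]) each
print(rays_d.norm(dim=-1).mean())   # directions are unit-normalized, ~1.0
```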
+ def matrix_to_square(mat):
+     l = len(mat.shape)
+     if l == 3:
+         return torch.cat([mat, torch.tensor([0, 0, 0, 1]).repeat(mat.shape[0], 1, 1).to(mat.device)], dim=1)
+     elif l == 4:
+         return torch.cat([mat, torch.tensor([0, 0, 0, 1]).repeat(mat.shape[0], mat.shape[1], 1, 1).to(mat.device)], dim=2)
+
+ def export_ply_for_gaussians(path, gaussians, opacity_threshold=0.00, T_norm=None):
+     sh_degree = int(math.sqrt((gaussians.shape[-1] - sum([3, 1, 3, 4])) / 3 - 1))
+
+     xyz, opacity, scale, rotation, feature = gaussians.float().split([3, 1, 3, 4, (sh_degree + 1)**2 * 3], dim=-1)
+
+     means3D = xyz.contiguous().float()
+     opacity = opacity.contiguous().float()
+     scales = scale.contiguous().float()
+     rotations = rotation.contiguous().float()
+     shs = feature.contiguous().float()  # [N, (sh_degree + 1)**2 * 3]
+
+     # prune by opacity
+     if opacity_threshold > 0:
+         mask = opacity[..., 0] >= opacity_threshold
+         means3D = means3D[mask]
+         opacity = opacity[mask]
+         scales = scales[mask]
+         rotations = rotations[mask]
+         shs = shs[mask]
+
+         print("Gaussian percentage: ", mask.float().mean())
+
+     if T_norm is not None:
+         # undo the scene normalization applied by normalize_cameras
+         means3D = means3D * T_norm.item()
+         scales = scales * T_norm.item()
+
+     # invert activations to make the values compatible with the original ply format
+     opacity = torch.log(opacity / (1 - opacity))
+     scales = torch.log(scales + 1e-8)
+
+     xyzs = means3D.detach()
+     f_dc = shs.detach().flatten(start_dim=1).contiguous()
+     opacities = opacity.detach()
+     scales = scales.detach()
+     rotations = rotations.detach()
+
+     l = ['x', 'y', 'z']
+     # all channels except the 3 DC
+     for i in range(f_dc.shape[1]):
+         l.append('f_dc_{}'.format(i))
+     l.append('opacity')
+     for i in range(scales.shape[1]):
+         l.append('scale_{}'.format(i))
+     for i in range(rotations.shape[1]):
+         l.append('rot_{}'.format(i))
+
+     dtype_full = [(attribute, 'f4') for attribute in l]
+
+     # optimal approach: build the structured array directly as a numpy recarray
+     attributes = torch.cat((xyzs, f_dc, opacities, scales, rotations), dim=1).cpu().numpy()
+
+     # create the recarray in one shot, avoiding per-row loops and type conversions
+     elements = np.rec.fromarrays([attributes[:, i] for i in range(attributes.shape[1])], names=l, formats=['f4'] * len(l))
+     el = PlyElement.describe(elements, 'vertex')
+
+     print(path)
+     PlyData([el]).write(path)
+
+     # plydata = PlyData([el])
+     # vert = plydata["vertex"]
+     # sorted_indices = np.argsort(
+     #     -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
+     #     / (1 + np.exp(-vert["opacity"]))
+     # )
+     # buffer = BytesIO()
+     # for idx in sorted_indices:
+     #     v = plydata["vertex"][idx]
+     #     position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
+     #     scales = np.exp(
+     #         np.array(
+     #             [v["scale_0"], v["scale_1"], v["scale_2"]],
+     #             dtype=np.float32,
+     #         )
+     #     )
+     #     rot = np.array(
+     #         [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
+     #         dtype=np.float32,
+     #     )
+     #     SH_C0 = 0.28209479177387814
+     #     color = np.array(
+     #         [
+     #             0.5 + SH_C0 * v["f_dc_0"],
+     #             0.5 + SH_C0 * v["f_dc_1"],
+     #             0.5 + SH_C0 * v["f_dc_2"],
+     #             1 / (1 + np.exp(-v["opacity"])),
+     #         ]
+     #     )
+     #     buffer.write(position.tobytes())
+     #     buffer.write(scales.tobytes())
+     #     buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
+     #     buffer.write(
+     #         ((rot / np.linalg.norm(rot)) * 128 + 128)
+     #         .clip(0, 255)
+     #         .astype(np.uint8)
+     #         .tobytes()
+     #     )
+
+     # with open(path + '.splat', "wb") as f:
+     #     f.write(buffer.getvalue())
+
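Each Gaussian row packs [xyz (3), opacity (1), scale (3), rotation quaternion (4), SH coefficients (3·(deg+1)²)], so the SH degree is recovered from the row width; opacity and scales pass through logit/log because the 3DGS `.ply` convention stores pre-activation values that viewers re-apply sigmoid/exp to. A sketch of the width-to-degree arithmetic used above:

```
import math

# row width = 3 (xyz) + 1 (opacity) + 3 (scale) + 4 (quat) + 3 * (deg + 1) ** 2 (SH)
for width in (14, 23, 38):
    deg = int(math.sqrt((width - 11) / 3 - 1))
    print(width, "->", deg)   # 14 -> 0, 23 -> 1, 38 -> 2
```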
+ @torch.amp.autocast(device_type="cuda", enabled=False)
+ def quaternion_slerp(
+     q0, q1, fraction, spin: int = 0, shortestpath: bool = True
+ ):
+     """Return spherical linear interpolation between two quaternions.
+     Args:
+         q0: first quaternion
+         q1: second quaternion
+         fraction: how much to interpolate between q0 vs q1 (if 0, closer to q0; if 1, closer to q1)
+         spin: how much of an additional spin to place on the interpolation
+         shortestpath: whether to return the short or long path to rotation
+     """
+     d = (q0 * q1).sum(-1)
+     if shortestpath:
+         # flip q1 where the dot product is negative so we take the short arc
+         neg = d < 0.0
+         d[neg] = -d[neg]
+         q1[neg] = -q1[neg]
+
+     _d = d.clamp(0, 1.0)
+
+     # theta = torch.arccos(d) * fraction
+     # q2 = q1 - q0 * d
+     # q2 = q2 / (q2.norm(dim=-1) + 1e-10)
+
+     # return torch.cos(theta) * q0 + torch.sin(theta) * q2
+
+     angle = torch.acos(_d) + spin * math.pi
+     isin = 1.0 / (torch.sin(angle) + 1e-10)
+     q0_ = q0 * (torch.sin((1.0 - fraction) * angle) * isin)[..., None]
+     q1_ = q1 * (torch.sin(fraction * angle) * isin)[..., None]
+
+     q = q0_ + q1_
+
+     # fall back to q0 when the endpoints are (nearly) identical
+     q[angle < 1e-5] = q0[angle < 1e-5]
+     # q[fraction < 1e-5] = q0[fraction < 1e-5]
+     # q[fraction > 1 - 1e-5] = q1[fraction > 1 - 1e-5]
+     # q[(d.abs() - 1).abs() < 1e-5] = q0[(d.abs() - 1).abs() < 1e-5]
+
+     return q
+
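Slerp weights the endpoints by sin((1-f)·θ)/sin θ and sin(f·θ)/sin θ, so the interpolant stays on the unit sphere. A hedged check, assuming the sign-corrected function above:

```
import torch
import torch.nn.functional as F

q0 = F.normalize(torch.randn(5, 4), dim=-1)
q1 = F.normalize(torch.randn(5, 4), dim=-1)
q_mid = quaternion_slerp(q0, q1, torch.full((5,), 0.5))
print(q_mid.norm(dim=-1))   # ~1 everywhere: slerp stays on the unit sphere
```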
+ def sample_from_two_pose(pose_a, pose_b, fraction, noise_strengths=[0, 0]):
+     """
+     Args:
+         pose_a: first pose
+         pose_b: second pose
+         fraction: interpolation weight in [0, 1] (0 -> pose_a, 1 -> pose_b)
+     """
+     quat_a = pose_a[..., :4]
+     quat_b = pose_b[..., :4]
+
+     dot = torch.sum(quat_a * quat_b, dim=-1, keepdim=True)
+     quat_b = torch.where(dot < 0, -quat_b, quat_b)
+
+     quaternion = quaternion_slerp(quat_a, quat_b, fraction)
+     quaternion = torch.nn.functional.normalize(quaternion + torch.randn_like(quaternion) * noise_strengths[0], dim=-1)
+
+     T = (1 - fraction)[:, None] * pose_a[..., 4:] + fraction[:, None] * pose_b[..., 4:]
+     T = T + torch.randn_like(T) * noise_strengths[1]
+
+     new_pose = pose_a.clone()
+     new_pose[..., :4] = quaternion
+     new_pose[..., 4:] = T
+     return new_pose
+
+ def sample_from_dense_cameras(dense_cameras, t, noise_strengths=[0, 0, 0, 0]):
+     N, C = dense_cameras.shape
+
+     # locate the two keyframes bracketing each continuous position t in [0, 1]
+     left = torch.floor(t * (N - 1)).long().clamp(0, N - 2)
+     right = left + 1
+     fraction = t * (N - 1) - left
+
+     a = torch.gather(dense_cameras, 0, left[..., None].repeat(1, C))
+     b = torch.gather(dense_cameras, 0, right[..., None].repeat(1, C))
+
+     new_pose = sample_from_two_pose(a[:, :7], b[:, :7], fraction, noise_strengths=noise_strengths[:2])
+
+     new_ins = (1 - fraction)[:, None] * a[:, 7:] + fraction[:, None] * b[:, 7:]
+
+     return torch.cat([new_pose, new_ins], dim=1)
+
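`sample_from_dense_cameras` treats t in [0, 1] as a continuous position along N keyframe cameras: the integer part of t·(N-1) selects the bracketing pair, and the fractional part slerps the rotation while lerping the translation and intrinsics, with optional Gaussian jitter via `noise_strengths`. A hedged sketch with hypothetical keyframes:

```
import torch
import torch.nn.functional as F

dense = torch.randn(8, 11)                        # 8 hypothetical keyframe cameras
dense[:, :4] = F.normalize(dense[:, :4], dim=-1)  # valid unit quaternions
t = torch.tensor([0.0, 0.37, 1.0])                # continuous trajectory positions
cams = sample_from_dense_cameras(dense, t)
print(cams.shape)                                 # torch.Size([3, 11])
```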