imlixinyang committed on
Commit c8df52d · 1 Parent(s): 123eeba
.gitignore ADDED
@@ -0,0 +1,4 @@
+ tmpfiles/
+ model.ckpt
+
+ **/__pycache__/**
README.md CHANGED
@@ -38,6 +38,8 @@ pip install torch torchvision
 pip install triton transformers pytorch_lightning omegaconf ninja numpy jaxtyping rich tensorboard einops moviepy==1.0.3 webdataset accelerate opencv-python lpips av plyfile ftfy peft tensorboard pandas flask
 ```

+ Please refer to the `requirements.txt` file for the exact package versions.
+
 - install ```gsplat@1.5.2``` and ```diffusers@wan-5Bi2v``` packages
 ```
 pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712
@@ -55,9 +57,10 @@ cd FlashWorld
 python app.py
 ```

- Then, enjoy your journey in FlashWorld!
-
+ Then, open your web browser and navigate to ```http://HOST_IP:7860``` to start exploring FlashWorld!

+ <!-- We also provide example trajectory json files and input images in the `examples/` directory. -->
+
 ## More Generation Results

 [https://github.com/user-attachments/assets/bbdbe5de-5e15-4471-b380-4d8191688d82](https://github.com/user-attachments/assets/53d41748-4c35-48c4-9771-f458421c0b38)
@@ -67,7 +70,6 @@ Then, enjoy your journey in FlashWorld!

 Licensed under the CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International)

-
 The code is released for academic research use only.

 If you have any questions, please contact me via [imlixinyang@gmail.com](mailto:imlixinyang@gmail.com).
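
The `/generate` endpoint defined in the new `app.py` (added below) consumes a JSON body with `text_prompt`/`image_prompt`, a `cameras` list, and a `resolution` triple. The following is a minimal sketch of that schema; the numeric values are placeholders, and the quaternion component order is an assumption here, not something the commit documents.

```python
# Hypothetical request body for POST /generate, mirroring the fields that
# app.py reads. All values are placeholders; the quaternion order is assumed.
camera = {
    "quaternion": [1.0, 0.0, 0.0, 0.0],  # identity rotation (assumed component order)
    "position": [0.0, 0.0, 0.0],         # camera center
    "fx": 500.0, "fy": 500.0,            # focal lengths in pixels
    "cx": 352.0, "cy": 240.0,            # principal point in pixels
}
request_body = {
    "text_prompt": "a cozy living room",
    "cameras": [dict(camera) for _ in range(24)],  # one entry per camera sample
    "resolution": [24, 480, 704],                  # [n_frame, height, width]
    "image_index": 0,                              # frame pinned to the image prompt, if any
}
```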
app.py ADDED
@@ -0,0 +1,651 @@
+ try:
+     import spaces
+     GPU = spaces.GPU
+     print("spaces GPU is available")
+ except ImportError:
+     def GPU(func):
+         return func
+
+ import os
+ import subprocess
+
+ # def install_cuda_toolkit():
+ #     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
+ #     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
+ #     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
+ #     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
+ #     subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
+ #     subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
+
+ #     os.environ["CUDA_HOME"] = "/usr/local/cuda"
+ #     os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
+ #     os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
+ #         os.environ["CUDA_HOME"],
+ #         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
+ #     )
+ #     # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
+ #     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
+
+ #     print("Successfully installed CUDA toolkit at: ", os.environ["CUDA_HOME"])
+
+ #     subprocess.call('rm /usr/bin/gcc', shell=True)
+ #     subprocess.call('rm /usr/bin/g++', shell=True)
+ #     subprocess.call('rm /usr/local/cuda/bin/gcc', shell=True)
+ #     subprocess.call('rm /usr/local/cuda/bin/g++', shell=True)
+
+ #     subprocess.call('ln -s /usr/bin/gcc-11 /usr/bin/gcc', shell=True)
+ #     subprocess.call('ln -s /usr/bin/g++-11 /usr/bin/g++', shell=True)
+
+ #     subprocess.call('ln -s /usr/bin/gcc-11 /usr/local/cuda/bin/gcc', shell=True)
+ #     subprocess.call('ln -s /usr/bin/g++-11 /usr/local/cuda/bin/g++', shell=True)
+
+ #     subprocess.call('gcc --version', shell=True)
+ #     subprocess.call('g++ --version', shell=True)
+
+ # install_cuda_toolkit()
+
+ # subprocess.run('pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712 --no-build-isolation --use-pep517', env={'CUDA_HOME': "/usr/local/cuda", "TORCH_CUDA_ARCH_LIST": "8.0;8.6"}, shell=True)
+
+ from flask import Flask, jsonify, request, send_file, render_template
+ import base64
+ import io
+ from PIL import Image
+ import torch
+ import numpy as np
+ import os
+ import argparse
+ import imageio
+ import json
+
+ import time
+ import threading
+
+ from concurrency_manager import ConcurrencyManager
+
+ from huggingface_hub import hf_hub_download
+
+ import einops
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ import imageio
+
+ from models import *
+ from utils import *
+
+ from transformers import T5TokenizerFast, UMT5EncoderModel
+
+ from diffusers import FlowMatchEulerDiscreteScheduler
+
+ class MyFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+     # Pick the schedule index whose timestep is closest to the query timestep.
+     def index_for_timestep(self, timestep, schedule_timesteps=None):
+         if schedule_timesteps is None:
+             schedule_timesteps = self.timesteps
+
+         return torch.argmin(
+             (timestep - schedule_timesteps.to(timestep.device)).abs(), dim=0).item()
+
+ class GenerationSystem(nn.Module):
+     def __init__(self, ckpt_path=None, device="cuda:0", offload_t5=False, offload_vae=False):
+         super().__init__()
+         self.device = device
+         self.offload_t5 = offload_t5
+         self.offload_vae = offload_vae
+
+         self.latent_dim = 48
+         self.temporal_downsample_factor = 4
+         self.spatial_downsample_factor = 16
+
+         self.feat_dim = 1024
+
+         self.latent_patch_size = 2
+
+         self.denoising_steps = [0, 250, 500, 750]
+
+         model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
+
+         self.vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float).eval()
+
+         # Strip the temporal padding from the VAE's causal 3D convolutions so
+         # single-frame inputs can be encoded without temporal context.
+         from models.autoencoder_kl_wan import WanCausalConv3d
+         with torch.no_grad():
+             for name, module in self.vae.named_modules():
+                 if isinstance(module, WanCausalConv3d):
+                     time_pad = module._padding[4]
+                     module.padding = (0, module._padding[2], module._padding[0])
+                     module._padding = (0, 0, 0, 0, 0, 0)
+                     module.weight = torch.nn.Parameter(module.weight[:, :, time_pad:].clone())
+
+         self.vae.requires_grad_(False)
+
+         self.register_buffer('latents_mean', torch.tensor(self.vae.config.latents_mean).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
+         self.register_buffer('latents_std', torch.tensor(self.vae.config.latents_std).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
+
+         self.latent_scale_fn = lambda x: (x - self.latents_mean) / self.latents_std
+         self.latent_unscale_fn = lambda x: x * self.latents_std + self.latents_mean
+
+         self.tokenizer = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer")
+
+         self.text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float32).eval().requires_grad_(False).to(self.device if not self.offload_t5 else "cpu")
+
+         self.transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float32).train().requires_grad_(False)
+
+         # Widen the patch embedding to accept raymap and condition-latent channels.
+         self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, 6 + self.latent_dim)))
+         # self.transformer.rope.freqs_f[:] = self.transformer.rope.freqs_f[:1]
+
+         # Extend proj_out so the transformer also emits feat_dim reconstruction features per patch.
+         weight = self.transformer.proj_out.weight.reshape(self.latent_patch_size ** 2, self.latent_dim, self.transformer.proj_out.weight.shape[1])
+         bias = self.transformer.proj_out.bias.reshape(self.latent_patch_size ** 2, self.latent_dim)
+
+         extra_weight = torch.randn(self.latent_patch_size ** 2, self.feat_dim, self.transformer.proj_out.weight.shape[1]) * 0.02
+         extra_bias = torch.zeros(self.latent_patch_size ** 2, self.feat_dim)
+
+         self.transformer.proj_out.weight = nn.Parameter(torch.cat([weight, extra_weight], dim=1).flatten(0, 1).detach().clone())
+         self.transformer.proj_out.bias = nn.Parameter(torch.cat([bias, extra_bias], dim=1).flatten(0, 1).detach().clone())
+
+         self.recon_decoder = WANDecoderPixelAligned3DGSReconstructionModel(self.vae, self.feat_dim, use_render_checkpointing=True, use_network_checkpointing=False).train().requires_grad_(False).to(self.device)
+
+         self.scheduler = MyFlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3)
+
+         self.register_buffer('timesteps', self.scheduler.timesteps.clone().to(self.device))
+
+         self.transformer.disable_gradient_checkpointing()
+         self.transformer.gradient_checkpointing = False
+
+         self.add_feedback_for_transformer()
+
+         if ckpt_path is not None:
+             state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+             self.transformer.load_state_dict(state_dict["transformer"])
+             self.recon_decoder.load_state_dict(state_dict["recon_decoder"])
+             print(f"Loaded {ckpt_path}.")
+
+         from quant import FluxFp8GeMMProcessor
+
+         FluxFp8GeMMProcessor(self.transformer)
+
+         del self.vae.post_quant_conv, self.vae.decoder
+         self.vae.to(self.device if not self.offload_vae else "cpu")
+
+         self.transformer.to(self.device)
+
+     def add_feedback_for_transformer(self):
+         # Reserve extra input channels so the previous step's features and prediction can be fed back in.
+         self.use_feedback = True
+         self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, self.feat_dim + self.latent_dim)))
+
+     def encode_text(self, texts):
+         max_sequence_length = 512
+
+         text_inputs = self.tokenizer(
+             texts,
+             padding="max_length",
+             max_length=max_sequence_length,
+             truncation=True,
+             add_special_tokens=True,
+             return_attention_mask=True,
+             return_tensors="pt",
+         )
+         if getattr(self, "offload_t5", False):
+             text_input_ids = text_inputs.input_ids.to("cpu")
+             mask = text_inputs.attention_mask.to("cpu")
+         else:
+             text_input_ids = text_inputs.input_ids.to(self.device)
+             mask = text_inputs.attention_mask.to(self.device)
+         seq_lens = mask.gt(0).sum(dim=1).long()
+
+         if getattr(self, "offload_t5", False):
+             with torch.no_grad():
+                 text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state.to(self.device)
+         else:
+             text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
+         text_embeds = [u[:v] for u, v in zip(text_embeds, seq_lens)]
+         text_embeds = torch.stack(
+             [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in text_embeds], dim=0
+         )
+         return text_embeds.float()
+
+     def forward_generator(self, noisy_latents, raymaps, condition_latents, t, text_embeds, cameras, render_cameras, image_height, image_width, need_3d_mode=True):
+
+         out = self.transformer(
+             hidden_states=torch.cat([noisy_latents, raymaps, condition_latents], dim=1),
+             timestep=t,
+             encoder_hidden_states=text_embeds,
+             return_dict=False,
+         )[0]
+
+         v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)
+
+         sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
+         latents_pred_2d = noisy_latents - sigma * v_pred
+
+         if need_3d_mode:
+             scene_params = self.recon_decoder(
+                 einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
+                 einops.rearrange(self.latent_unscale_fn(latents_pred_2d.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
+                 cameras
+             ).flatten(1, -2)
+
+             images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")
+
+             latents_pred_3d = einops.rearrange(self.latent_scale_fn(self.vae.encode(
+                 einops.rearrange(images_pred, 'B T C H W -> (B T) C H W', T=images_pred.shape[1]).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
+             ).latent_dist.sample().to(self.device)).squeeze(2), '(B T) C H W -> B C T H W', T=images_pred.shape[1]).to(noisy_latents.dtype)
+
+         return {
+             '2d': latents_pred_2d,
+             '3d': latents_pred_3d if need_3d_mode else None,
+             'rgb_3d': images_pred if need_3d_mode else None,
+             'scene': scene_params if need_3d_mode else None,
+             'feat': feats
+         }
+
+     @torch.no_grad()
+     @torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda")
+     def generate(self, cameras, n_frame, image=None, text="", image_index=0, image_height=480, image_width=704, video_output_path=None):
+         with torch.no_grad():
+             batch_size = 1
+
+             cameras = cameras.to(self.device).unsqueeze(0)
+
+             if cameras.shape[1] != n_frame:
+                 render_cameras = cameras.clone()
+                 cameras = sample_from_dense_cameras(cameras.squeeze(0), torch.linspace(0, 1, n_frame, device=self.device)).unsqueeze(0)
+             else:
+                 render_cameras = cameras
+
+             cameras, ref_w2c, T_norm = normalize_cameras(cameras, return_meta=True, n_frame=None)
+
+             render_cameras = normalize_cameras(render_cameras, ref_w2c=ref_w2c, T_norm=T_norm, n_frame=None)
+
+             text = "[Static] " + text
+
+             text_embeds = self.encode_text([text])
+             # neg_text_embeds = self.encode_text([""]).repeat(batch_size, 1, 1)
+
+             masks = torch.zeros(batch_size, n_frame, device=self.device)
+
+             condition_latents = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+             if image is not None:
+                 image = image.to(self.device)
+
+                 latent = self.latent_scale_fn(self.vae.encode(
+                     image.unsqueeze(0).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
+                 ).latent_dist.sample().to(self.device)).squeeze(2)
+
+                 masks[:, image_index] = 1
+                 condition_latents[:, :, image_index] = latent
+
+             raymaps = create_raymaps(cameras, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor)
+             raymaps = einops.rearrange(raymaps, 'B T H W C -> B C T H W', T=n_frame)
+
+             noise = torch.randn(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+             noisy_latents = noise
+
+             torch.cuda.empty_cache()
+
+             if self.use_feedback:
+                 prev_latents_pred = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+                 prev_feats = torch.zeros(batch_size, self.feat_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+             for i in range(len(self.denoising_steps)):
+                 t_ids = torch.full((noisy_latents.shape[0],), self.denoising_steps[i], device=self.device)
+
+                 t = self.timesteps[t_ids]
+
+                 if self.use_feedback:
+                     _condition_latents = torch.cat([condition_latents, prev_feats, prev_latents_pred], dim=1)
+                 else:
+                     _condition_latents = condition_latents
+
+                 if i < len(self.denoising_steps) - 1:
+                     out = self.forward_generator(noisy_latents, raymaps, _condition_latents, t, text_embeds, cameras, cameras, image_height, image_width, need_3d_mode=True)
+
+                     latents_pred = out["3d"]
+
+                     if self.use_feedback:
+                         prev_latents_pred = latents_pred
+                         prev_feats = out['feat']
+
+                     noisy_latents = self.scheduler.scale_noise(latents_pred, self.timesteps[torch.full((noisy_latents.shape[0],), self.denoising_steps[i + 1], device=self.device)], torch.randn_like(noise))
+
+                 else:
+                     out = self.transformer(
+                         hidden_states=torch.cat([noisy_latents, raymaps, _condition_latents], dim=1),
+                         timestep=t,
+                         encoder_hidden_states=text_embeds,
+                         return_dict=False,
+                     )[0]
+
+                     v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)
+
+                     sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
+                     latents_pred = noisy_latents - sigma * v_pred
+
+                     scene_params = self.recon_decoder(
+                         einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
+                         einops.rearrange(self.latent_unscale_fn(latents_pred.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
+                         cameras
+                     ).flatten(1, -2)
+
+             if video_output_path is not None:
+                 interpolated_images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")
+
+                 interpolated_images_pred = einops.rearrange(interpolated_images_pred[0].clamp(-1, 1).add(1).div(2), 'T C H W -> T H W C')
+
+                 interpolated_images_pred = [torch.cat([img], dim=1).detach().cpu().mul(255).numpy().astype(np.uint8) for i, img in enumerate(interpolated_images_pred.unbind(0))]
+
+                 imageio.mimwrite(video_output_path, interpolated_images_pred, fps=15, quality=8, macro_block_size=1)
+
+             scene_params = scene_params[0]
+
+             scene_params = scene_params.detach().cpu()
+
+             return scene_params, ref_w2c, T_norm
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--port', type=int, default=7860)
+     parser.add_argument("--ckpt", default=None)
+     parser.add_argument("--gpu", type=int, default=0)
+     parser.add_argument("--cache_dir", type=str, default="./tmpfiles")
+     parser.add_argument("--offload_t5", action="store_true")
+     parser.add_argument("--max_concurrent", type=int, default=1, help="Maximum concurrent generation tasks")
+     args, _ = parser.parse_known_args()
+
+     # Ensure model.ckpt exists, download if not present
+     if args.ckpt is None:
+         from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+         ckpt_path = os.path.join(HUGGINGFACE_HUB_CACHE, "models--imlixinyang--FlashWorld", "snapshots", "6a8e88c6f88678ac098e4c82675f0aee555d6e5d", "model.ckpt")
+         if not os.path.exists(ckpt_path):
+             ckpt_path = hf_hub_download(repo_id="imlixinyang/FlashWorld", filename="model.ckpt")
+     else:
+         ckpt_path = args.ckpt
+
+     app = Flask(__name__)
+
+     # Initialize the GenerationSystem
+     device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
+     generation_system = GenerationSystem(ckpt_path=ckpt_path, device=device, offload_t5=args.offload_t5)
+
+     # Initialize the concurrency manager
+     concurrency_manager = ConcurrencyManager(max_concurrent=args.max_concurrent)
+
+     @app.after_request
+     def after_request(response):
+         response.headers.add('Access-Control-Allow-Origin', '*')
+         response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
+         response.headers.add('Access-Control-Allow-Methods', 'GET,PUT,POST,DELETE,OPTIONS')
+         return response
+
+     @GPU
+     def generate_wrapper(cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path=None):
+         """Wrapper around the generation function, used for concurrency control."""
+         return generation_system.generate(cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path)
+
+     def job_generate(file_id, cache_dir, payload):
+         """Generation task run by a worker thread: generates, writes to disk, and returns download info."""
+         # Unpack arguments
+         cameras = payload["cameras"]
+         n_frame = payload["n_frame"]
+         image = payload["image"]
+         text_prompt = payload["text_prompt"]
+         image_index = payload["image_index"]
+         image_height = payload["image_height"]
+         image_width = payload["image_width"]
+         data = payload["raw_request"]
+
+         # Run generation
+         scene_params, ref_w2c, T_norm = generation_system.generate(
+             cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path=None
+         )
+
+         # Save request metadata
+         with open(os.path.join(cache_dir, f'{file_id}.json'), 'w') as f:
+             json.dump(data, f)
+
+         # Export the PLY file
+         splat_path = os.path.join(cache_dir, f'{file_id}.ply')
+         export_ply_for_gaussians(splat_path, scene_params, opacity_threshold=0.001, T_norm=T_norm)
+
+         file_size = os.path.getsize(splat_path) if os.path.exists(splat_path) else 0
+
+         return {
+             'file_id': file_id,
+             'file_path': splat_path,
+             'file_size': file_size,
+             'download_url': f'/download/{file_id}'
+         }
+
+     @app.route('/generate', methods=['POST', 'OPTIONS'])
+     def generate():
+         # Handle preflight request
+         if request.method == 'OPTIONS':
+             return jsonify({'status': 'ok'})
+
+         try:
+             data = request.get_json(force=True)
+
+             image_prompt = data.get('image_prompt', None)
+             text_prompt = data.get('text_prompt', "")
+             cameras = data.get('cameras')
+             resolution = data.get('resolution')
+             image_index = data.get('image_index', 0)
+
+             n_frame, image_height, image_width = resolution
+
+             if not image_prompt and text_prompt == "":
+                 return jsonify({'error': 'No Prompts provided'}), 400
+
+             # Process the image
+             if image_prompt:
+                 # image_prompt can be a file path or a base64 string
+                 if os.path.exists(image_prompt):
+                     image_prompt = Image.open(image_prompt)
+                 else:
+                     # image_prompt may look like "data:image/png;base64,...."
+                     if ',' in image_prompt:
+                         image_prompt = image_prompt.split(',', 1)[1]
+
+                     try:
+                         image_bytes = base64.b64decode(image_prompt)
+                         image_prompt = Image.open(io.BytesIO(image_bytes))
+                     except Exception as img_e:
+                         return jsonify({'error': f'Image decode error: {str(img_e)}'}), 400
+
+                 image = image_prompt.convert('RGB')
+
+                 w, h = image.size
+
+                 # center crop
+                 if image_height / h > image_width / w:
+                     scale = image_height / h
+                 else:
+                     scale = image_width / w
+
+                 new_h = int(image_height / scale)
+                 new_w = int(image_width / scale)
+
+                 image = image.crop(((w - new_w) // 2, (h - new_h) // 2,
+                                     new_w + (w - new_w) // 2, new_h + (h - new_h) // 2)).resize((image_width, image_height))
+
+                 for camera in cameras:
+                     camera['fx'] = camera['fx'] * scale
+                     camera['fy'] = camera['fy'] * scale
+                     camera['cx'] = (camera['cx'] - (w - new_w) // 2) * scale
+                     camera['cy'] = (camera['cy'] - (h - new_h) // 2) * scale
+
+                 image = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0 * 2 - 1
+             else:
+                 image = None
+
+             cameras = torch.stack([
+                 torch.from_numpy(np.array([camera['quaternion'][0], camera['quaternion'][1], camera['quaternion'][2], camera['quaternion'][3], camera['position'][0], camera['position'][1], camera['position'][2], camera['fx'] / image_width, camera['fy'] / image_height, camera['cx'] / image_width, camera['cy'] / image_height], dtype=np.float32))
+                 for camera in cameras
+             ], dim=0)
+
+             file_id = str(int(time.time() * 1000))
+
+             # Assemble task parameters; defer execution and disk writes to the worker thread
+             payload = {
+                 'cameras': cameras,
+                 'n_frame': n_frame,
+                 'image': image,
+                 'text_prompt': text_prompt,
+                 'image_index': image_index,
+                 'image_height': image_height,
+                 'image_width': image_width,
+                 'raw_request': data,
+             }
+
+             # Submit the task to the concurrency manager (asynchronous)
+             task_id = concurrency_manager.submit_task(
+                 job_generate, file_id, args.cache_dir, payload
+             )
+
+             # Return queue info immediately after submission
+             queue_status = concurrency_manager.get_queue_status()
+             queued_tasks = queue_status.get('queued_tasks', [])
+             try:
+                 queue_position = queued_tasks.index(task_id) + 1
+             except ValueError:
+                 # If a worker thread has already picked the task up, treat it as started (position 0)
+                 queue_position = 0
+
+             return jsonify({
+                 'success': True,
+                 'task_id': task_id,
+                 'file_id': file_id,
+                 'queue': {
+                     'queued_count': queue_status.get('queued_count', 0),
+                     'running_count': queue_status.get('running_count', 0),
+                     'position': queue_position
+                 }
+             }), 202
+
+         except Exception as e:
+             return jsonify({'error': f'Server error: {str(e)}'}), 500
+
+     @app.route('/download/<file_id>', methods=['GET'])
+     def download_file(file_id):
+         """Download the generated PLY file."""
+         file_path = os.path.join(args.cache_dir, f'{file_id}.ply')
+
+         if not os.path.exists(file_path):
+             return jsonify({'error': 'File not found'}), 404
+
+         return send_file(file_path, as_attachment=True, download_name=f'{file_id}.ply')
+
+     @app.route('/delete/<file_id>', methods=['DELETE', 'POST', 'OPTIONS'])
+     def delete_file_endpoint(file_id):
+         """Delete a generated file and its metadata (called by the front end after the download completes)."""
+         # CORS preflight
+         if request.method == 'OPTIONS':
+             return jsonify({'status': 'ok'})
+
+         try:
+             ply_path = os.path.join(args.cache_dir, f'{file_id}.ply')
+             json_path = os.path.join(args.cache_dir, f'{file_id}.json')
+             deleted = []
+             for path in [ply_path, json_path]:
+                 if os.path.exists(path):
+                     os.remove(path)
+                     deleted.append(os.path.basename(path))
+             return jsonify({'success': True, 'deleted': deleted})
+         except Exception as e:
+             return jsonify({'success': False, 'error': str(e)}), 500
+
+     @app.route('/status', methods=['GET'])
+     def get_status():
+         """Get system status and queue info."""
+         try:
+             queue_status = concurrency_manager.get_queue_status()
+             return jsonify({
+                 'success': True,
+                 'status': queue_status,
+                 'timestamp': time.time()
+             })
+         except Exception as e:
+             return jsonify({'error': f'Failed to get status: {str(e)}'}), 500
+
+     @app.route('/task/<task_id>', methods=['GET'])
+     def get_task_status(task_id):
+         """Get the status of a specific task (queue position and, once finished, file info)."""
+         try:
+             task = concurrency_manager.get_task_status(task_id)
+             if not task:
+                 return jsonify({'error': 'Task not found'}), 404
+
+             queue_status = concurrency_manager.get_queue_status()
+             queued_tasks = queue_status.get('queued_tasks', [])
+             try:
+                 queue_position = queued_tasks.index(task_id) + 1
+             except ValueError:
+                 queue_position = 0
+
+             resp = {
+                 'success': True,
+                 'task_id': task_id,
+                 'status': task.status.value,
+                 'created_at': task.created_at,
+                 'started_at': task.started_at,
+                 'completed_at': task.completed_at,
+                 'error': task.error,
+                 'queue': {
+                     'queued_count': queue_status.get('queued_count', 0),
+                     'running_count': queue_status.get('running_count', 0),
+                     'position': queue_position
+                 }
+             }
+
+             if task.status.value == 'completed' and isinstance(task.result, dict):
+                 resp.update({
+                     'file_id': task.result.get('file_id'),
+                     'file_path': task.result.get('file_path'),
+                     'file_size': task.result.get('file_size'),
+                     'download_url': task.result.get('download_url'),
+                     'generation_time': (task.completed_at - task.started_at)
+                 })
+
+             # Update task status
+
+             return jsonify(resp)
+         except Exception as e:
+             return jsonify({'error': f'Failed to get task status: {str(e)}'}), 500
+
+     @app.route("/")
+     def index():
+         return send_file("index.html")
+
+     os.makedirs(args.cache_dir, exist_ok=True)
+
+     # Periodic background cleanup: delete cache files not accessed/modified for over 30 minutes
+     def cleanup_worker(cache_dir: str, max_age_seconds: int = 1800, interval_seconds: int = 300):
+         while True:
+             try:
+                 now = time.time()
+                 for name in os.listdir(cache_dir):
+                     # Only clean up task-related .ply/.json files
+                     if not (name.endswith('.ply') or name.endswith('.json')):
+                         continue
+                     path = os.path.join(cache_dir, name)
+                     try:
+                         mtime = os.path.getmtime(path)
+                         if now - mtime > max_age_seconds:
+                             os.remove(path)
+                     except FileNotFoundError:
+                         pass
+                     except Exception:
+                         # Ignore per-file errors and keep cleaning
+                         pass
+             except Exception:
+                 # Keep the thread from exiting on unexpected errors
+                 pass
+             time.sleep(interval_seconds)
+
+     cleaner_thread = threading.Thread(target=cleanup_worker, args=(args.cache_dir,), daemon=True)
+     cleaner_thread.start()
+
+     app.run(host='0.0.0.0', port=args.port)
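
Since `/generate` replies with HTTP 202 and a `task_id` rather than the finished scene, a client has to poll `/task/<task_id>` and then fetch the PLY from `download_url`. Below is a minimal sketch of that flow, assuming a local server on port 7860 and the third-party `requests` package; the prompt and camera values are placeholders, not part of this commit.

```python
# Minimal client sketch for the endpoints defined in app.py above.
import time
import requests

BASE = "http://localhost:7860"

camera = {"quaternion": [1.0, 0.0, 0.0, 0.0], "position": [0.0, 0.0, 0.0],
          "fx": 500.0, "fy": 500.0, "cx": 352.0, "cy": 240.0}
body = {"text_prompt": "a cozy living room",
        "cameras": [dict(camera) for _ in range(24)],
        "resolution": [24, 480, 704]}  # [n_frame, height, width]

# 1. Submit: the server replies 202 with a task_id immediately.
task_id = requests.post(f"{BASE}/generate", json=body).json()["task_id"]

# 2. Poll until the worker thread finishes the task.
while True:
    status = requests.get(f"{BASE}/task/{task_id}").json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(1.0)

# 3. Download the exported Gaussian-splat PLY on success.
if status["status"] == "completed":
    with open("scene.ply", "wb") as f:
        f.write(requests.get(BASE + status["download_url"]).content)
```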
concurrency_manager.py ADDED
@@ -0,0 +1,203 @@
+ import threading
+ import time
+ import uuid
+ from typing import Dict, List, Optional, Callable, Any
+ from dataclasses import dataclass
+ from enum import Enum
+
+ class TaskStatus(Enum):
+     QUEUED = "queued"
+     RUNNING = "running"
+     COMPLETED = "completed"
+     FAILED = "failed"
+
+ @dataclass
+ class Task:
+     task_id: str
+     status: TaskStatus
+     created_at: float
+     started_at: Optional[float] = None
+     completed_at: Optional[float] = None
+     result: Optional[Any] = None
+     error: Optional[str] = None
+     function: Optional[Callable] = None
+     args: tuple = ()
+     kwargs: Optional[dict] = None
+
+     def __post_init__(self):
+         if self.kwargs is None:
+             self.kwargs = {}
+
+ class ConcurrencyManager:
+     def __init__(self, max_concurrent: int = 2):
+         """
+         Concurrency control manager.
+
+         Args:
+             max_concurrent: maximum number of tasks that may run at once
+         """
+         self.max_concurrent = max_concurrent
+         self.running_tasks: Dict[str, Task] = {}
+         self.queued_tasks: List[Task] = []
+         self.completed_tasks: Dict[str, Task] = {}
+         self.lock = threading.RLock()
+         self.worker_threads: List[threading.Thread] = []
+         self.shutdown_event = threading.Event()
+
+         # Start the worker threads
+         self._start_workers()
+
+     def _start_workers(self):
+         """Start the worker threads."""
+         for i in range(self.max_concurrent):
+             worker = threading.Thread(target=self._worker_loop, daemon=True)
+             worker.start()
+             self.worker_threads.append(worker)
+
+     def _worker_loop(self):
+         """Main loop of a worker thread."""
+         while not self.shutdown_event.is_set():
+             try:
+                 task = self._get_next_task()
+                 if task:
+                     self._execute_task(task)
+                 else:
+                     # Sleep briefly when there is no task
+                     time.sleep(0.1)
+             except Exception as e:
+                 print(f"Worker thread error: {e}")
+                 time.sleep(1)
+
+     def _get_next_task(self) -> Optional[Task]:
+         """Get the next task to execute."""
+         with self.lock:
+             if self.queued_tasks:
+                 return self.queued_tasks.pop(0)
+             return None
+
+     def _execute_task(self, task: Task):
+         """Execute a task."""
+         try:
+             with self.lock:
+                 task.status = TaskStatus.RUNNING
+                 task.started_at = time.time()
+                 self.running_tasks[task.task_id] = task
+
+             # Run the task
+             if task.function:
+                 result = task.function(*task.args, **task.kwargs)
+                 task.result = result
+
+             # Mark as completed
+             with self.lock:
+                 task.status = TaskStatus.COMPLETED
+                 task.completed_at = time.time()
+                 self.completed_tasks[task.task_id] = task
+                 if task.task_id in self.running_tasks:
+                     del self.running_tasks[task.task_id]
+
+         except Exception as e:
+             # Mark as failed
+             with self.lock:
+                 task.status = TaskStatus.FAILED
+                 task.completed_at = time.time()
+                 task.error = str(e)
+                 self.completed_tasks[task.task_id] = task
+                 if task.task_id in self.running_tasks:
+                     del self.running_tasks[task.task_id]
+
+     def submit_task(self, func: Callable, *args, **kwargs) -> str:
+         """
+         Submit a task.
+
+         Args:
+             func: the function to execute
+             *args: positional arguments for the function
+             **kwargs: keyword arguments for the function
+
+         Returns:
+             task_id: the task ID
+         """
+         task_id = str(uuid.uuid4())
+         task = Task(
+             task_id=task_id,
+             status=TaskStatus.QUEUED,
+             created_at=time.time(),
+             function=func,
+             args=args,
+             kwargs=kwargs
+         )
+
+         with self.lock:
+             self.queued_tasks.append(task)
+
+         return task_id
+
+     def get_task_status(self, task_id: str) -> Optional[Task]:
+         """Get the status of a task."""
+         with self.lock:
+             if task_id in self.running_tasks:
+                 return self.running_tasks[task_id]
+             elif task_id in self.completed_tasks:
+                 return self.completed_tasks[task_id]
+             else:
+                 # Check tasks still waiting in the queue
+                 for task in self.queued_tasks:
+                     if task.task_id == task_id:
+                         return task
+                 return None
+
+     def get_queue_status(self) -> Dict[str, Any]:
+         """Get queue status."""
+         with self.lock:
+             return {
+                 "max_concurrent": self.max_concurrent,
+                 "running_count": len(self.running_tasks),
+                 "queued_count": len(self.queued_tasks),
+                 "completed_count": len(self.completed_tasks),
+                 "running_tasks": [task.task_id for task in self.running_tasks.values()],
+                 "queued_tasks": [task.task_id for task in self.queued_tasks],
+             }
+
+     def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> Task:
+         """
+         Wait for a task to finish.
+
+         Args:
+             task_id: the task ID
+             timeout: timeout in seconds; None means wait indefinitely
+
+         Returns:
+             Task: the finished task
+         """
+         start_time = time.time()
+
+         while True:
+             task = self.get_task_status(task_id)
+             if task and task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
+                 return task
+
+             if timeout and (time.time() - start_time) > timeout:
+                 raise TimeoutError(f"Task {task_id} timed out after {timeout} seconds")
+
+             time.sleep(0.1)
+
+     def cleanup_old_tasks(self, max_age_hours: int = 24):
+         """Clean up old tasks."""
+         current_time = time.time()
+         max_age_seconds = max_age_hours * 3600
+
+         with self.lock:
+             # Remove completed tasks that are too old
+             old_tasks = [
+                 task_id for task_id, task in self.completed_tasks.items()
+                 if current_time - task.completed_at > max_age_seconds
+             ]
+             for task_id in old_tasks:
+                 del self.completed_tasks[task_id]
+
+     def shutdown(self):
+         """Shut down the manager."""
+         self.shutdown_event.set()
+         for worker in self.worker_threads:
+             worker.join(timeout=5)
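
For reference, `ConcurrencyManager` can also be exercised on its own, outside the Flask app. A small stand-alone sketch follows; the `slow_square` job is a placeholder for a real generation task, not code from the repo.

```python
# Minimal usage sketch for the ConcurrencyManager defined above.
import time
from concurrency_manager import ConcurrencyManager

def slow_square(x):
    time.sleep(0.5)  # stand-in for a long-running generation job
    return x * x

manager = ConcurrencyManager(max_concurrent=2)
task_ids = [manager.submit_task(slow_square, i) for i in range(4)]

# wait_for_task blocks until the task is COMPLETED or FAILED.
for tid in task_ids:
    task = manager.wait_for_task(tid, timeout=30)
    print(tid, task.status.value, task.result)

manager.shutdown()
```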
index.html ADDED
@@ -0,0 +1,2130 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>FlashWorld Demo</title>
7
+ <meta name="description" content="">
8
+ <style>
9
+ body {
10
+ margin: 0;
11
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
12
+ background: #1a1a1a;
13
+ color: #ffffff;
14
+ overflow: hidden;
15
+ }
16
+
17
+ .main-container {
18
+ display: flex;
19
+ height: 100vh;
20
+ flex-direction: column;
21
+ }
22
+
23
+ .header {
24
+ background: rgba(0, 0, 0, 0.8);
25
+ padding: 15px 20px;
26
+ text-align: center;
27
+ border-bottom: 1px solid rgba(255, 255, 255, 0.1);
28
+ flex-shrink: 0;
29
+ }
30
+
31
+ .header h1 {
32
+ margin: 0;
33
+ color: white;
34
+ font-size: 1.8em;
35
+ font-weight: 600;
36
+ margin-bottom: 8px;
37
+ }
38
+ .header-title-wrap {
39
+ display: inline-flex;
40
+ align-items: center;
41
+ gap: 8px;
42
+ position: relative;
43
+ }
44
+
45
+ .header-links {
46
+ display: flex;
47
+ justify-content: center;
48
+ gap: 20px;
49
+ margin-top: 8px;
50
+ }
51
+
52
+ .header-links a {
53
+ color: #60a5fa;
54
+ text-decoration: none;
55
+ font-size: 0.9em;
56
+ padding: 5px 10px;
57
+ border: 1px solid #60a5fa;
58
+ border-radius: 5px;
59
+ transition: all 0.3s ease;
60
+ }
61
+
62
+ .header-links a:hover {
63
+ background: #60a5fa;
64
+ color: white;
65
+ }
66
+
67
+ .content-container {
68
+ display: flex;
69
+ flex: 1;
70
+ overflow: hidden;
71
+ }
72
+
73
+ .left-panel {
74
+ width: 280px;
75
+ background: rgba(0, 0, 0, 0.7);
76
+ border-right: 1px solid rgba(255, 255, 255, 0.1);
77
+ padding: 20px;
78
+ overflow-y: auto;
79
+ flex-shrink: 0;
80
+ }
81
+
82
+ .center-panel {
83
+ flex: 1;
84
+ position: relative;
85
+ background: #000;
86
+ display: flex;
87
+ justify-content: center;
88
+ align-items: center;
89
+ }
90
+
91
+ .right-panel {
92
+ width: 300px;
93
+ background: rgba(0, 0, 0, 0.7);
94
+ border-left: 1px solid rgba(255, 255, 255, 0.1);
95
+ padding: 20px;
96
+ overflow-y: auto;
97
+ flex-shrink: 0;
98
+ }
99
+
100
+ .guidance {
101
+ color: #e5e7eb;
102
+ }
103
+
104
+ .guidance h2 {
105
+ color: #ffffff;
106
+ margin-top: 0;
107
+ font-size: 1.3em;
108
+ border-bottom: 2px solid #60a5fa;
109
+ padding-bottom: 8px;
110
+ margin-bottom: 20px;
111
+ }
112
+
113
+ .gui-container h2{
114
+ color: #ffffff;
115
+ margin-top: 0;
116
+ font-size: 1.3em;
117
+ border-bottom: 2px solid #60fae5;
118
+ padding-bottom: 8px;
119
+ margin-bottom: 20px;
120
+ }
121
+
122
+ .step {
123
+ margin: 12px 0;
124
+ padding: 12px;
125
+ background: rgba(96, 165, 250, 0.1);
126
+ border-radius: 6px;
127
+ border-left: 3px solid #60a5fa;
128
+ }
129
+
130
+ .step h3 {
131
+ margin: 0 0 8px 0;
132
+ color: #ffffff;
133
+ font-size: 1em;
134
+ }
135
+
136
+ .step p {
137
+ margin: 4px 0;
138
+ line-height: 1.4;
139
+ font-size: 0.85em;
140
+ color: #d1d5db;
141
+ }
142
+
143
+ .controls-info {
144
+ background: rgba(168, 85, 247, 0.1);
145
+ border-left: 3px solid #a855f7;
146
+ }
147
+
148
+ .keyboard-shortcuts {
149
+ background: rgba(34, 197, 94, 0.1);
150
+ border-left: 3px solid #22c55e;
151
+ }
152
+
153
+ .loading {
154
+ position: absolute;
155
+ top: 50%;
156
+ left: 50%;
157
+ min-width: 300px;
158
+ min-height: 200px;
159
+ transform: translate(-50%, -50%);
160
+ background: rgba(0, 0, 0, 0.9);
161
+ color: white;
162
+ padding: 20px;
163
+ border-radius: 10px;
164
+ display: none;
165
+ z-index: 1000;
166
+ text-align: center;
167
+ vertical-align: middle;
168
+ }
169
+
170
+ .generation-info {
171
+ background: rgba(34, 197, 94, 0.1);
172
+ border: 1px solid #22c55e;
173
+ border-radius: 8px;
174
+ padding: 15px;
175
+ margin: 10px 0;
176
+ color: #22c55e;
177
+ font-family: 'Courier New', monospace;
178
+ font-size: 0.9em;
179
+ }
180
+
181
+ .progress-container {
182
+ width: 100%;
183
+ background: rgba(255, 255, 255, 0.1);
184
+ border-radius: 10px;
185
+ overflow: hidden;
186
+ margin: 10px 0;
187
+ position: relative;
188
+ }
189
+
190
+ .progress-bar {
191
+ height: 20px;
192
+ background: linear-gradient(90deg, #60a5fa, #3b82f6);
193
+ width: 0%;
194
+ transition: width 0.3s ease;
195
+ border-radius: 10px;
196
+ position: relative;
197
+ }
198
+
199
+ .progress-text {
200
+ position: absolute;
201
+ top: 50%;
202
+ left: 50%;
203
+ transform: translate(-50%, -50%);
204
+ color: white;
205
+ font-weight: bold;
206
+ font-size: 0.8em;
207
+ white-space: nowrap;
208
+ }
209
+
210
+ /* Info tooltip */
211
+ .info-tip {
212
+ display: inline-block;
213
+ position: relative;
214
+ margin-left: 8px;
215
+ width: 16px;
216
+ height: 16px;
217
+ line-height: 16px;
218
+ text-align: center;
219
+ border-radius: 50%;
220
+ background: #3b82f6;
221
+ color: #fff;
222
+ font-size: 12px;
223
+ cursor: default;
224
+ user-select: none;
225
+ }
226
+ .info-tip .tooltip {
227
+ display: none;
228
+ position: absolute;
229
+ left: 0;
230
+ top: calc(100% + 8px); /* show below the icon */
231
+ transform: none;
232
+ background: rgba(0,0,0,0.9);
233
+ color: #e5e7eb;
234
+ border: 1px solid rgba(255,255,255,0.15);
235
+ border-radius: 8px;
236
+ padding: 10px 12px;
237
+ font-size: 12px;
238
+ width: 360px; /* wider tooltip */
239
+ white-space: normal;
240
+ z-index: 2000; /* above GUI and other elements */
241
+ box-shadow: 0 4px 12px rgba(0,0,0,0.4);
242
+ }
243
+ .info-tip:hover .tooltip {
244
+ display: block;
245
+ }
246
+
247
+ .status-bar {
248
+ background: rgba(0, 0, 0, 0.9);
249
+ color: #60a5fa;
250
+ padding: 8px 15px;
251
+ font-family: 'Courier New', monospace;
252
+ font-size: 0.8em;
253
+ border-top: 1px solid rgba(255, 255, 255, 0.1);
254
+ flex-shrink: 0;
255
+ }
256
+
257
+ .canvas-container {
258
+ width: 100%;
259
+ height: 100%;
260
+ display: flex;
261
+ justify-content: center;
262
+ align-items: center;
263
+ background:
264
+ repeating-linear-gradient(
265
+ 45deg,
266
+ #1a1a1a 0px,
267
+ #1a1a1a 10px,
268
+ #2a2a2a 10px,
269
+ #2a2a2a 20px
270
+ );
271
+ position: relative;
272
+ }
273
+
274
+ .canvas-wrapper {
275
+ position: relative;
276
+ border: 2px solid #444;
277
+ background: #111;
278
+ box-shadow:
279
+ 0 0 20px rgba(0, 0, 0, 0.5),
280
+ inset 0 0 10px rgba(0, 0, 0, 0.3);
281
+ border-radius: 4px;
282
+ }
283
+
284
+ .canvas-wrapper canvas {
285
+ display: block;
286
+ border-radius: 2px;
287
+ }
288
+
289
+ /* Add a subtle animation to the canvas wrapper */
290
+ .canvas-wrapper:hover {
291
+ border-color: #666;
292
+ box-shadow:
293
+ 0 0 30px rgba(0, 0, 0, 0.7),
294
+ inset 0 0 15px rgba(0, 0, 0, 0.4);
295
+ }
296
+
297
+ /* Progress & status beautify */
298
+ .progress-container {
299
+ width: 100%;
300
+ height: 18px;
301
+ background: linear-gradient(180deg, rgba(255,255,255,0.06), rgba(255,255,255,0.02));
302
+ border: 1px solid rgba(255,255,255,0.12);
303
+ border-radius: 999px;
304
+ overflow: hidden;
305
+ box-shadow: 0 2px 10px rgba(0,0,0,0.35) inset;
306
+ position: relative;
307
+ }
308
+ .progress-bar {
309
+ height: 100%;
310
+ background: linear-gradient(90deg, #60a5fa, #8b5cf6);
311
+ box-shadow: 0 0 10px rgba(96,165,250,0.65);
312
+ position: relative;
313
+ transition: width .15s ease;
314
+ }
315
+ .progress-text {
316
+ position: absolute;
317
+ top: 50%;
318
+ left: 50%;
319
+ transform: translate(-50%, -50%);
320
+ font-size: 11px;
321
+ color: #f8fafc;
322
+ text-shadow: 0 1px 2px rgba(0,0,0,0.5);
323
+ pointer-events: none;
324
+ white-space: nowrap;
325
+ }
326
+
327
+ .status-badges {
328
+ display: flex;
329
+ gap: 8px;
330
+ flex-wrap: wrap;
331
+ margin-top: 8px;
332
+ }
333
+ .badge {
334
+ display: inline-flex;
335
+ align-items: center;
336
+ gap: 6px;
337
+ padding: 6px 10px;
338
+ border-radius: 8px;
339
+ font-size: 12px;
340
+ border: 1px solid rgba(255,255,255,0.12);
341
+ background: rgba(255,255,255,0.06);
342
+ }
343
+ .badge .dot { width: 8px; height: 8px; border-radius: 999px; }
344
+ .badge.queue .dot { background: #f59e0b; }
345
+ .badge.running .dot { background: #22c55e; }
346
+ .badge.time .dot { background: #60a5fa; }
347
+ .badge.bytes .dot { background: #a78bfa; }
348
+
349
+ .details-grid {
350
+ display: grid;
351
+ grid-template-columns: repeat(2, minmax(0, 1fr));
352
+ gap: 6px 12px;
353
+ margin-top: 8px;
354
+ font-size: 12px;
355
+ color: #cbd5e1;
356
+ }
357
+ .details-grid div { opacity: 0.9; }
358
+
359
+ /* Canvas resizing indicator */
360
+ .canvas-wrapper.resizing {
361
+ border-color: #60a5fa;
362
+ box-shadow:
363
+ 0 0 25px rgba(96, 165, 250, 0.3),
364
+ inset 0 0 10px rgba(96, 165, 250, 0.1);
365
+ }
366
+
367
+ .canvas-wrapper.resizing::after {
368
+ content: "Resizing...";
369
+ position: absolute;
370
+ top: 50%;
371
+ left: 50%;
372
+ transform: translate(-50%, -50%);
373
+ color: #60a5fa;
374
+ font-size: 12px;
375
+ font-weight: bold;
376
+ z-index: 10;
377
+ pointer-events: none;
378
+ }
379
+
380
+ /* GUI Panel Styling */
381
+ .gui-panel {
382
+ background: rgba(0, 0, 0, 0.8);
383
+ border-radius: 8px;
384
+ padding: 15px;
385
+ min-height: 400px;
386
+ }
387
+
388
+ .gui-panel .lil-gui {
389
+ --background-color: rgba(0, 0, 0, 0.8);
390
+ --text-color: #ffffff;
391
+ --title-background-color: rgba(96, 165, 250, 0.2);
392
+ --title-text-color: #ffffff;
393
+ --widget-color: rgba(96, 165, 250, 0.3);
394
+ --hover-color: rgba(96, 165, 250, 0.5);
395
+ }
396
+
397
+ /* Ensure GUI is visible */
398
+ .lil-gui {
399
+ position: relative !important;
400
+ z-index: 1000 !important;
401
+ }
402
+
403
+ @media (max-width: 1200px) {
404
+ .left-panel {
405
+ width: 250px;
406
+ }
407
+
408
+ .right-panel {
409
+ width: 280px;
410
+ }
411
+ }
412
+
413
+ @media (max-width: 768px) {
414
+ .content-container {
415
+ flex-direction: column;
416
+ }
417
+
418
+ .left-panel, .right-panel {
419
+ width: 100%;
420
+ height: auto;
421
+ max-height: 200px;
422
+ }
423
+
424
+ .center-panel {
425
+ flex: 1;
426
+ min-height: 400px;
427
+ }
428
+ }
429
+ </style>
430
+ <script type="importmap">
431
+ {
432
+ "imports": {
433
+ "three": "https://cdnjs.cloudflare.com/ajax/libs/three.js/0.174.0/three.module.js",
434
+ "@sparkjsdev/spark": "https://sparkjs.dev/releases/spark/0.1.6/spark.module.js",
435
+ "lil-gui": "https://cdn.jsdelivr.net/npm/lil-gui@0.20/+esm"
436
+ }
437
+ }
438
+ </script>
439
+ </head>
440
+ <body>
441
+ <div class="main-container">
442
+ <!-- Header Section -->
443
+ <header class="header">
444
+ <div style="display: flex; justify-content: space-between; align-items: center; width: 100%;">
445
+ <h1 style="margin: 0; flex: 1; text-align: left;">
446
+ <span class="header-title-wrap">FlashWorld Spark Demo
447
+ <span class="info-tip">!
448
+ <span class="tooltip" style="max-width: 260px; text-align: left;">Note: Front-end real-time rend ering in Spark uses compressed Gaussian Splat attributes. Visual quality in this demo may be lower than offline/back-end rendering.
449
+ Also, the generation is fast but the downloading may be slow, please be patient.
450
+ </span>
451
+ </span>
452
+ </span>
453
+ </h1>
454
+ <div class="header-links" style="margin-left: 20px;">
455
+ <a href="#" target="_blank">Paper</a>
456
+ <a href="#" target="_blank">Code</a>
457
+ <a href="#" target="_blank">Project Page</a>
458
+ </div>
459
+ </div>
460
+ </header>
461
+
462
+ <!-- Main Content Container -->
463
+ <div class="content-container">
464
+ <!-- Left Panel: Simplified Guidance -->
465
+ <div class="left-panel">
466
+ <div class="guidance">
467
+ <h2>Instructions</h2>
468
+
469
+ <div class="step">
470
+ <h3>1. Configure</h3>
471
+ <p>Set FOV and Resolution and Click "Fix Configurations"</p>
472
+ </div>
473
+
474
+
475
+ <div class="step">
476
+ <h3>2. Set Camera Trajectory</h3>
477
+ <p><b>Manual:</b> Navigate with mouse and keyboard, press <kbd>Space</kbd> to record</p>
478
+ <p><b>Template:</b> Select template type and click "Generate Trajectory"</p>
479
+ <p><b>JSON:</b> Load trajectory from JSON file</p>
480
+ </div>
481
+
482
+ <div class="step">
483
+ <h3>3. Add Prompts</h3>
484
+ <p>Upload image or enter text description</p>
485
+ </div>
486
+
487
+ <div class="step">
488
+ <h3>4. Generate</h3>
489
+ <p>Click "Generate!" to create your scene</p>
490
+ </div>
491
+
492
+ <div class="step controls-info">
493
+ <h3>Controls</h3>
494
+ <p><strong>Mouse/QE:</strong> Rotate view</p>
495
+ <p><strong>WASD/RF:</strong> Move</p>
496
+ <p><strong>Space:</strong> Record camera</p>
497
+ </div>
498
+
499
+ </div>
500
+ </div>
501
+
502
+ <!-- Center Panel: Canvas -->
503
+ <div class="center-panel">
504
+ <div class="canvas-container" id="canvas-container">
505
+ <div class="canvas-wrapper" id="canvas-wrapper">
506
+ <div class="loading" id="loading">
507
+ <h3>🎬 Generating Scene...</h3>
508
+ <p>Please wait while we create your 3D scene</p>
509
+ <div id="generation-info" class="generation-info" style="display: none;">
510
+ <div><strong>Generation Time:</strong> <span id="generation-time">-</span> seconds</div>
511
+ <div><strong>File Size:</strong> <span id="file-size">-</span> MB</div>
512
+ </div>
513
+ <div id="download-progress" style="display: none;">
514
+ <div class="progress-container">
515
+ <div class="progress-bar" id="progress-bar"></div>
516
+ <div class="progress-text" id="progress-text">0%</div>
517
+ </div>
518
+ <div class="status-badges" id="status-badges" style="display: none;">
519
+ <div class="badge queue" id="badge-queue"><span class="dot"></span><span id="badge-queue-text">Queue</span></div>
520
+ <div class="badge running" id="badge-running" style="display: none;"><span class="dot"></span><span id="badge-running-text">Running</span></div>
521
+ <div class="badge time" id="badge-time" style="display: none;"><span class="dot"></span><span id="badge-time-text">00:00</span></div>
522
+ </div>
523
+ <div id="queue-details" class="details-grid" style="display: none;"></div>
524
+ <div id="download-details" class="details-grid" style="display: none;"></div>
525
+ </div>
526
+ </div>
527
+ </div>
528
+ </div>
529
+ </div>
530
+
531
+ <!-- Right Panel: GUI -->
532
+ <div class="right-panel">
533
+ <div class="gui-container">
534
+ <!-- <h2>GUI</h2> -->
535
+ <div class="gui-panel" id="gui-container">
536
+ <!-- GUI will be inserted here -->
537
+ </div>
538
+ </div>
539
+
540
+ <!-- Image Preview Area -->
541
+ <div id="image-preview-area" style="padding: 10px; display: none;">
542
+ <div style="font-size: 12px; color: #ccc; margin-bottom: 8px; text-align: left;">Input Image Preview</div>
543
+ <div style="text-align: center;">
544
+ <img id="preview-img" style="max-width: 100%; max-height: 200px; border-radius: 4px; box-shadow: 0 2px 8px rgba(0,0,0,0.3);" />
545
+ </div>
546
+ </div>
547
+ </div>
548
+ </div>
549
+
550
+ <!-- Status Bar -->
551
+ <div class="status-bar" id="status-bar">
552
+ Ready to generate 3D scenes | Cameras: 0 | Status: Waiting for input
553
+ </div>
554
+ </div>
555
+
556
+ <!-- Hidden File Inputs -->
557
+ <input id="file-input" type="file" accept=".jpg,.png,.jpeg" multiple="true" style="display: none;" />
558
+ <input id="json-input" type="file" accept=".json" multiple="false" style="display: none;" />
559
+
560
+ <script type="module">
561
+ // =========================
562
+ // Imports & Global Variables
563
+ // =========================
564
+ import * as THREE from "three";
565
+ import { SplatMesh, SparkControls, textSplats } from "@sparkjsdev/spark";
566
+ import GUI from "lil-gui";
567
+
568
+ // Scene, Camera, Renderer, Controls
569
+ const scene = new THREE.Scene();
570
+ const camera = new THREE.PerspectiveCamera(60, window.innerWidth / window.innerHeight, 0.1, 1000);
571
+ camera.position.set(0, 0, 1.5);
572
+ const renderer = new THREE.WebGLRenderer();
573
+ renderer.setSize(window.innerWidth, window.innerHeight);
574
+
575
+ // Wait for DOM to be ready
576
+ function initializeRenderer() {
577
+ const canvasWrapper = document.getElementById('canvas-wrapper');
578
+ if (canvasWrapper) {
579
+ canvasWrapper.appendChild(renderer.domElement);
580
+
581
+ // Set initial canvas size based on current resolution
582
+ updateCanvasSize();
583
+ console.log('Canvas initialized in wrapper');
584
+ } else {
585
+ console.error('Canvas wrapper not found');
586
+ }
587
+ }
588
+
589
+ // Update canvas size based on selected resolution
590
+ function updateCanvasSize() {
591
+ const canvasWrapper = document.getElementById('canvas-wrapper');
592
+ if (!canvasWrapper) return;
593
+
594
+ // Show resizing indicator
595
+ canvasWrapper.classList.add('resizing');
596
+
597
+ // Get current resolution from GUI options
598
+ const resolution = guiOptions.Resolution.split('x');
599
+ const width = parseInt(resolution[2]) || 704; // W
600
+ const height = parseInt(resolution[1]) || 480; // H
601
+
602
+ // Set canvas size
603
+ renderer.setSize(width, height);
604
+ camera.aspect = width / height;
605
+ camera.updateProjectionMatrix();
606
+
607
+ // Update wrapper size to match canvas
608
+ canvasWrapper.style.width = width + 'px';
609
+ canvasWrapper.style.height = height + 'px';
610
+
611
+ // Remove resizing indicator after a short delay
612
+ setTimeout(() => {
613
+ canvasWrapper.classList.remove('resizing');
614
+ }, 300);
615
+
616
+ console.log('Canvas size updated:', width, 'x', height);
617
+ }
618
+
619
+ const controls = new SparkControls({ canvas: renderer.domElement });
620
+
621
+ // Camera splats and params
622
+ const cameraSplats = [];
623
+ const cameraParams = [];
624
+ const interpolatedCamerasSplats = [];
625
+
626
+ // State
627
+ let fixGenerationFOV = false;
628
+ let inputImageBase64 = null;
629
+ let inputImageResolution = null;
630
+ let currentGeneratedSplat = null; // Track the currently generated scene
631
+
632
+ // UI Elements
633
+ const loadingElement = document.getElementById('loading');
634
+ const statusBar = document.getElementById('status-bar');
635
+
636
+ // GUI variable - declare early
637
+ let gui = null;
638
+
639
+ // Status update function
640
+ function updateStatus(message, cameraCount = null) {
641
+ const cameraText = cameraCount !== null ? `Cameras: ${cameraCount}` : `Cameras: ${cameraParams.length}`;
642
+ statusBar.textContent = `${message} | ${cameraText} | Status: ${fixGenerationFOV ? 'Ready to record' : 'Configure settings'}`;
643
+ }
644
+
645
+ // Show/hide loading
646
+ function showLoading(show) {
647
+ loadingElement.style.display = show ? 'block' : 'none';
648
+ }
649
+
650
+ // Show generation info
651
+ function showGenerationInfo(generationTime, fileSize) {
652
+ const generationInfo = document.getElementById('generation-info');
653
+ const generationTimeElement = document.getElementById('generation-time');
654
+ const fileSizeElement = document.getElementById('file-size');
655
+
656
+ generationTimeElement.textContent = generationTime.toFixed(2);
657
+ fileSizeElement.textContent = (fileSize / (1024 * 1024)).toFixed(2);
658
+ generationInfo.style.display = 'block';
659
+ }
660
+
661
+ // Show download progress
662
+ function showDownloadProgress() {
663
+ const downloadProgress = document.getElementById('download-progress');
664
+ downloadProgress.style.display = 'block';
665
+ const qd = document.getElementById('queue-details');
666
+ const dd = document.getElementById('download-details');
667
+ const badges = document.getElementById('status-badges');
668
+ if (qd) qd.style.display = 'none';
669
+ if (dd) dd.style.display = 'none';
670
+ if (badges) badges.style.display = 'none';
671
+ }
672
+
673
+ // Update progress bar
674
+ function updateProgressBar(percentage) {
675
+ const progressBar = document.getElementById('progress-bar');
676
+ const progressText = document.getElementById('progress-text');
677
+
678
+ progressBar.style.width = percentage + '%';
679
+ progressText.textContent = `${Math.round(percentage)}%`;
680
+ }
681
+
682
+ // Update progress label text (stage indicator)
683
+ function setProgressLabel(text) {
684
+ const progressText = document.getElementById('progress-text');
685
+ if (progressText) progressText.textContent = text;
686
+ }
687
+
688
+ // ==============
689
+ // Queue handling
690
+ // ==============
691
+ let queuePollTimer = null;
692
+ let currentTaskId = null;
693
+ let initialQueuePosition = null;
694
+ let latestGenerationTime = null;
695
+ let lastDownloadPct = 0;
696
+ let lastDownloadUpdateTs = 0;
697
+
698
+ function showQueueWaiting(position, runningCount, queuedCount) {
699
+ // Use only the progress bar to show queue progress (from initial position to 0)
700
+ showDownloadProgress();
701
+ if (initialQueuePosition === null) {
702
+ // Initialize from first seen position; ensure >= 1 so 0 -> 100%
703
+ const initPos = (typeof position === 'number') ? position : 0;
704
+ initialQueuePosition = Math.max(initPos, 1);
705
+ }
706
+ const percent = initialQueuePosition && initialQueuePosition > 0
707
+ ? Math.max(0, Math.min(100, ((initialQueuePosition - (position || 0)) / initialQueuePosition) * 100))
708
+ : 0;
709
+ updateProgressBar(percent);
710
+ const totalWaiting = (position || 0) + (queuedCount || 0);
711
+ if (position !== null && position !== undefined) {
712
+ const pctText = `${Math.round(percent)}%`;
713
+ if (totalWaiting > 0) {
714
+ setProgressLabel(`Queued ${position}/${totalWaiting} (${pctText})`);
715
+ } else {
716
+ setProgressLabel(`Queued ${position} (${pctText})`);
717
+ }
718
+ } else {
719
+ setProgressLabel('Queued');
720
+ }
721
+ }
722
+
723
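+ // Queue-polling contract (inferred from the handling below): GET
+ // {BackendAddress}/task/{taskId} every 2 seconds, expecting JSON like
+ // { success, status: 'queued' | 'running' | 'completed' | 'failed',
+ //   queue: { position, running_count, queued_count }, download_url, file_id }.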
+ async function pollTaskUntilReady(taskId) {
724
+ currentTaskId = taskId;
725
+ initialQueuePosition = null;
726
+ if (queuePollTimer) {
727
+ clearInterval(queuePollTimer);
728
+ queuePollTimer = null;
729
+ }
730
+ const queueStartTs = Date.now();
731
+
732
+ const pollOnce = async () => {
733
+ try {
734
+ const resp = await fetch(`${guiOptions.BackendAddress}/task/${taskId}`);
735
+ if (!resp.ok) return;
736
+ const info = await resp.json();
737
+ if (!info || !info.success) return;
738
+
739
+ const pos = info.queue && typeof info.queue.position === 'number' ? info.queue.position : 0;
740
+ const running = info.queue ? info.queue.running_count : 0;
741
+ const queued = info.queue ? info.queue.queued_count : 0;
742
+ if (info.status === 'queued' || info.status === 'running') {
743
+ // Only progress bar; set stage label
744
+ if (info.status === 'queued') {
745
+ showQueueWaiting(pos, running, queued);
746
+ } else {
747
+ // Transitioned to running: finalize queue progress visually
748
+ updateProgressBar(100);
749
+ showDownloadProgress();
750
+ setProgressLabel('Generating...');
751
+ }
752
+ }
753
+
754
+ if (info.status === 'completed' && info.download_url) {
755
+ clearInterval(queuePollTimer);
756
+ queuePollTimer = null;
757
+ latestGenerationTime = typeof info.generation_time === 'number' ? info.generation_time : null;
758
+ // Proceed to download the generated file like the normal path
759
+ updateStatus('Downloading generated scene...', cameraParams.length);
760
+ const response = await fetch(guiOptions.BackendAddress + info.download_url);
761
+ if (!response.ok) throw new Error(`HTTP error! status: ${response.status}`);
762
+ const contentLength = response.headers.get('content-length');
763
+ const total = parseInt(contentLength || '0', 10);
764
+ // Show generation info immediately once we know it and total size from headers
765
+ showGenerationInfo(latestGenerationTime || 0, total);
766
+ let loaded = 0;
767
+ const reader = response.body.getReader();
768
+ const chunks = [];
769
+ updateProgressBar(0);
770
+ setProgressLabel('Downloading 0%');
771
+ lastDownloadPct = 0;
772
+ lastDownloadUpdateTs = 0;
773
+ while (true) {
774
+ const { done, value } = await reader.read();
775
+ if (done) break;
776
+ chunks.push(value);
777
+ loaded += value.length;
778
+ if (total) {
779
+ const pct = Math.min(100, (loaded / total) * 100);
780
+ const now = Date.now();
781
+ const rounded = Math.round(pct);
782
+ // Throttle and enforce monotonic increase
783
+ if (rounded > Math.round(lastDownloadPct) || (now - lastDownloadUpdateTs) > 200) {
784
+ lastDownloadPct = Math.max(lastDownloadPct, pct);
785
+ updateProgressBar(lastDownloadPct);
786
+ setProgressLabel(`Downloading ${Math.round(lastDownloadPct)}%`);
787
+ lastDownloadUpdateTs = now;
788
+ }
789
+ }
790
+ }
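+ // Download-progress updates above are throttled to at most one per 200 ms
+ // (unless the integer percentage increased) and kept monotonic, so the bar
+ // never moves backwards on uneven chunk sizes.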
791
+
792
+ if (instructionSplat) {
793
+ scene.remove(instructionSplat);
794
+ console.log('Instruction splat removed');
795
+ instructionSplat = null;
796
+ }
797
+
798
+ const blob = new Blob(chunks);
799
+ const url = URL.createObjectURL(blob);
800
+ // Continue to load the splat
801
+ updateStatus('Loading generated scene...', cameraParams.length);
802
+
803
+ const GeneratedSplat = new SplatMesh({ url });
804
+ scene.add(GeneratedSplat);
805
+ currentGeneratedSplat = GeneratedSplat;
806
+ updateStatus('Scene generated successfully!', cameraParams.length);
807
+ // Show generation time and total file size (MB)
808
+ showGenerationInfo(latestGenerationTime || 0, total || blob.size);
809
+ // Notify backend to delete the server file after client has downloaded it
810
+ try {
811
+ if (info.file_id) {
812
+ const resp = await fetch(`${guiOptions.BackendAddress}/delete/${info.file_id}`, { method: 'POST' });
813
+ if (!resp.ok) console.warn('Delete notify failed');
814
+ }
815
+ } catch (e) {
816
+ console.warn('Delete notify error', e);
817
+ }
818
+ hideDownloadProgress();
819
+ showLoading(false);
820
+ } else if (info.status === 'failed') {
821
+ clearInterval(queuePollTimer);
822
+ queuePollTimer = null;
823
+ throw new Error(info.error || 'Generation failed');
824
+ }
825
+ } catch (e) {
826
+ console.debug('Polling error:', e);
827
+ }
828
+ };
829
+
830
+ await pollOnce();
831
+ queuePollTimer = setInterval(pollOnce, 2000);
832
+ }
833
+
834
+ // Hide download progress
835
+ function hideDownloadProgress() {
836
+ const downloadProgress = document.getElementById('download-progress');
837
+ downloadProgress.style.display = 'none';
838
+ }
839
+
840
+ // Playback scrubber (0..1)
841
+ let userCameraState = null; // Stores the user's camera state before playback
842
+
843
+ // Get the interpolated camera at normalized time t in [0, 1]
844
+ function getInterpolatedCameraAtTime(t) {
845
+ if (cameraParams.length === 0) {
846
+ return camera;
847
+ }
848
+
849
+ if (cameraParams.length === 1) {
850
+ return cameraParams[0];
851
+ }
852
+
853
+ // Clamp t to the valid range
854
+ const clampedT = Math.max(0, Math.min(1, t));
855
+
856
+ // Locate the position within the camera sequence
857
+ const cameraIndex = clampedT * (cameraParams.length - 1);
858
+ const startIndex = Math.min(Math.floor(cameraIndex), cameraParams.length - 2);
859
+ const endIndex = startIndex + 1;
860
+ const startCamera = cameraParams[startIndex];
861
+ const endCamera = cameraParams[endIndex];
862
+
863
+ // Interpolation fraction between the two cameras
864
+ const _t = cameraIndex - startIndex;
865
+
866
+ // Interpolate with interpolateTwoCameras
867
+ return interpolateTwoCameras(startCamera, endCamera, _t);
868
+ }
869
+
870
+ function setCameraByScrub(t) {
871
+ if (cameraParams.length === 0) return;
872
+ const clampedT = Math.max(0, Math.min(1, t));
873
+ const camT = getInterpolatedCameraAtTime(clampedT);
874
+ camera.position.copy(camT.position);
875
+ camera.quaternion.copy(camT.quaternion);
876
+ camera.fov = camT.fov;
877
+ camera.updateProjectionMatrix();
878
+ }
879
+
880
+ // Supported resolutions
881
+ const supportedResolutions = [
882
+ { frame: 24, width: 704, height: 480 },
883
+ { frame: 24, width: 480, height: 704 }
884
+ ];
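+ // Resolution strings are encoded as "NxHxW" (frame count x height x width),
+ // e.g. "24x480x704" means 24 frames at height 480 and width 704.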
885
+
886
+ // GUI Options - declare early
887
+ const guiOptions = {
888
+ // Backend address; defaults to this page's host on port 7860
889
+ BackendAddress: `${window.location.protocol}//${window.location.hostname}:7860`,
890
+ FOV: 60,
891
+ LoadFromJson: () => {
892
+ const jsonInput = document.querySelector("#json-input");
893
+ if (jsonInput) jsonInput.click();
894
+ },
895
+ LoadTrajectoryFromJson: () => {
896
+ if (!fixGenerationFOV) {
897
+ updateStatus('Warning: Please fix configuration first before loading trajectory', cameraParams.length);
898
+ return;
899
+ }
900
+ // Set a flag indicating that only the trajectory should be loaded
901
+ window.loadTrajectoryOnly = true;
902
+ const jsonInput = document.querySelector("#json-input");
903
+ if (jsonInput) jsonInput.click();
904
+ },
905
+ fixGenerationFOV: () => {
906
+ // These controllers will be set when GUI is initialized
907
+ if (window.fixGenerationFOVController) window.fixGenerationFOVController.disable();
908
+ fixGenerationFOV = true;
909
+
910
+ const new_camera = new THREE.PerspectiveCamera(guiOptions.FOV, guiOptions.Resolution.split('x')[2] / guiOptions.Resolution.split('x')[1]);
911
+ new_camera.position.set(0, 0, 0);
912
+ new_camera.quaternion.set(0, 0, 0, 1);
913
+ new_camera.updateProjectionMatrix();
914
+
915
+ const cameraSplat = createCameraSplat(new_camera);
916
+ cameraSplats.push(cameraSplat);
917
+ cameraParams.push({
918
+ position: new_camera.position.clone(),
919
+ quaternion: new_camera.quaternion.clone(),
920
+ fov: new_camera.fov,
921
+ aspect: new_camera.aspect,
922
+ });
923
+ scene.add(cameraSplat);
924
+
925
+ updateStatus('Camera settings fixed. Press Space to record cameras.', cameraParams.length);
926
+ },
927
+ Resolution: `${supportedResolutions[0].frame}x${supportedResolutions[0].height}x${supportedResolutions[0].width}`,
928
+ VisualizeCameraSplats: true,
929
+ VisualizeInterpolatedCameras: true,
930
+ inputImagePrompt: () => {
931
+ const fileInput = document.querySelector("#file-input");
932
+ if (fileInput) {
933
+ // Only trigger file selection; the global handler does the cropping and preview update
934
+ fileInput.click();
935
+ }
936
+ },
937
+ imageIndex: 0,
938
+ inputTextPrompt: "",
939
+
940
+ // Camera trajectory templates
941
+ trajectoryMode: "Manual",
942
+ templateType: "Move Forward",
943
+ cameraTrajectory: "Manual",
944
+ trajectorySettings: {
945
+ angle: 180, // angle in degrees (180, 360)
946
+ tilt: 15 // tilt in degrees (15, 30, 45)
947
+ },
948
+ generateTrajectory: () => {
949
+ generateCameraTrajectory(guiOptions.templateType);
950
+ },
951
+ saveTrajectoryToJson: () => {
952
+ if (cameraParams.length === 0) {
953
+ updateStatus('No cameras to save.', cameraParams.length);
954
+ console.warn('No cameras to save');
955
+ return;
956
+ }
957
+
958
+ // Build JSON payload compatible with loader
959
+ const [nStr, hStr, wStr] = guiOptions.Resolution.split('x');
960
+ const n = parseInt(nStr), h = parseInt(hStr), w = parseInt(wStr);
961
+ const payload = {
962
+ // image_prompt: null,
963
+ // text_prompt: guiOptions.inputTextPrompt || "",
964
+ // image_index: guiOptions.imageIndex || 0,
965
+ // resolution: [n, h, w],
966
+ cameras: cameraParams.map(cam => ({
967
+ position: [cam.position.x, cam.position.y, cam.position.z],
968
+ quaternion: [cam.quaternion.w, cam.quaternion.x, cam.quaternion.y, cam.quaternion.z]
969
+ }))
970
+ };
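+ // The saved file therefore looks like (quaternions stored w-first, matching the loader):
+ // { "cameras": [ { "position": [0, 0, 0], "quaternion": [1, 0, 0, 0] }, ... ] }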
971
+
972
+ const blob = new Blob([JSON.stringify(payload, null, 2)], { type: 'application/json' });
973
+ const url = URL.createObjectURL(blob);
974
+ const a = document.createElement('a');
975
+ a.href = url;
976
+ a.download = `trajectory_${Date.now()}.json`;
977
+ document.body.appendChild(a);
978
+ a.click();
979
+ document.body.removeChild(a);
980
+ URL.revokeObjectURL(url);
981
+ updateStatus('Trajectory saved to JSON.', cameraParams.length);
982
+ },
983
+ clearAllCameras: () => {
984
+ if (cameraParams.length <= 1) {
985
+ updateStatus('No cameras to clear (first camera is always preserved)', cameraParams.length);
986
+ return;
987
+ }
988
+
989
+ // Keep the first camera, remove all others
990
+ const firstCamera = cameraParams[0];
991
+ const firstSplat = cameraSplats[0];
992
+
993
+ // Remove all camera splats except the first one
994
+ for (let i = cameraSplats.length - 1; i >= 1; i--) {
995
+ scene.remove(cameraSplats[i]);
996
+ }
997
+
998
+ // Keep only the first camera in arrays
999
+ cameraSplats.length = 1;
1000
+ cameraParams.length = 1;
1001
+
1002
+ // Clear all interpolated camera splats from scene
1003
+ interpolatedCamerasSplats.forEach(splat => scene.remove(splat));
1004
+ interpolatedCamerasSplats.length = 0;
1005
+
1006
+ updateStatus('Cameras cleared (first camera preserved). Ready to add more cameras.', 1);
1007
+ console.log('Cameras cleared, first camera preserved');
1008
+ },
1009
+ // Playback scrub value (0..1)
1010
+ playbackT: 0,
1011
+
1012
+ generate: () => {
1013
+ // Check that enough cameras have been recorded
1014
+ if (cameraParams.length < 2) {
1015
+ console.error('Need at least 2 cameras to generate. Please press Space to record more cameras.');
1016
+ updateStatus('Error: Need at least 2 cameras', cameraParams.length);
1017
+ return;
1018
+ }
1019
+
1020
+ updateStatus('Preparing generation...', cameraParams.length);
1021
+
1022
+ // Remove the previously generated scene
1023
+ if (currentGeneratedSplat) {
1024
+ scene.remove(currentGeneratedSplat);
1025
+ currentGeneratedSplat = null;
1026
+ console.log('Previous generated scene removed');
1027
+ }
1028
+
1029
+ // Reset the progress bar info
1030
+ const generationTimeElement = document.getElementById('generation-time');
1031
+ const fileSizeElement = document.getElementById('file-size');
1032
+ const progressBar = document.getElementById('progress-bar');
1033
+ const progressText = document.getElementById('progress-text');
1034
+
1035
+ if (generationTimeElement) generationTimeElement.textContent = '-';
1036
+ if (fileSizeElement) fileSizeElement.textContent = '-';
1037
+ if (progressBar) progressBar.style.width = '0%';
1038
+ if (progressText) progressText.textContent = '0%';
1039
+
1040
+ // Hide the generation info and download progress
1041
+ const generationInfo = document.getElementById('generation-info');
1042
+ const downloadProgress = document.getElementById('download-progress');
1043
+ if (generationInfo) generationInfo.style.display = 'none';
1044
+ if (downloadProgress) downloadProgress.style.display = 'none';
1045
+
1046
+ showLoading(true);
1047
+
1048
+ // Generate interpolated cameras and visualize them
1049
+ const interpolatedCameras = interpolateCameras(cameraParams, parseInt(guiOptions.Resolution.split('x')[0]));
1050
+ interpolatedCameras.forEach(cam => {
1051
+ const interpolatedCameraSplat = createCameraSplat(cam, [0.5, 0.5, 0.5]);
1052
+ interpolatedCamerasSplats.push(interpolatedCameraSplat);
1053
+ scene.add(interpolatedCameraSplat);
1054
+ });
1055
+
1056
+ console.log('Sending request to backend...');
1057
+ console.log('Interpolated cameras:', interpolatedCameras.length);
1058
+ updateStatus('Sending request to backend...', cameraParams.length);
1059
+
1060
+ // Build the request for the backend
1061
+ let requestUrl, requestBody;
1062
+
1063
+ // Flask backend: POST directly to /generate
1065
+ requestUrl = guiOptions.BackendAddress + '/generate';
1066
+ requestBody = JSON.stringify({
1067
+ image_prompt: inputImageBase64 ? inputImageBase64 : "",
1068
+ text_prompt: guiOptions.inputTextPrompt,
1069
+ image_index: 0,
1070
+ resolution: [
1071
+ parseInt(guiOptions.Resolution.split('x')[0]),
1072
+ parseInt(guiOptions.Resolution.split('x')[1]),
1073
+ parseInt(guiOptions.Resolution.split('x')[2])
1074
+ ],
1075
+ cameras: interpolatedCameras.map(cam => ({
1076
+ position: [cam.position.x, cam.position.y, cam.position.z],
1077
+ quaternion: [cam.quaternion.w, cam.quaternion.x, cam.quaternion.y, cam.quaternion.z],
1078
+ fx: 0.5 / Math.tan(0.5 * cam.fov * Math.PI / 180) * parseInt(guiOptions.Resolution.split('x')[1]),
1079
+ fy: 0.5 / Math.tan(0.5 * cam.fov * Math.PI / 180) * parseInt(guiOptions.Resolution.split('x')[1]),
1080
+ cx: inputImageBase64 && inputImageResolution
1081
+ ? 0.5 * inputImageResolution.width
1082
+ : 0.5 * parseInt(guiOptions.Resolution.split('x')[2]),
1083
+ cy: inputImageBase64 && inputImageResolution
1084
+ ? 0.5 * inputImageResolution.height
1085
+ : 0.5 * parseInt(guiOptions.Resolution.split('x')[1]),
1086
+ }))
1087
+ });
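+ // fx/fy above follow the pinhole model, f = 0.5 * H / tan(FOV / 2), expressed in
+ // pixels of the target height H; cx/cy default to the image center unless an
+ // input image supplies its own resolution.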
1088
1091
+
1092
+ // Ask the backend to generate (async: returns a task_id, then we poll the queue)
1093
+ fetch(requestUrl, {
1094
+ method: 'POST',
1095
+ headers: { 'Content-Type': 'application/json' },
1096
+ mode: 'cors',
1097
+ body: requestBody
1098
+ })
1099
+ .then(response => {
1100
+ const contentType = response.headers.get('content-type');
1101
+ if (contentType && contentType.includes('application/json')) {
1102
+ return response.json();
1103
+ } else {
1104
+ return response.blob().then(blob => {
1105
+ const url = URL.createObjectURL(blob);
1106
+ return { url };
1107
+ });
1108
+ }
1109
+ })
1110
+ .then(data => {
1111
+ console.log(data);
1112
+ // Async queue protocol: the backend returns task_id + queue info (HTTP 202)
1114
+ if (data && data.success && data.task_id) {
1115
+ updateStatus('Queued request submitted. Waiting in queue...', cameraParams.length);
1116
+ showQueueWaiting(data.queue?.position || 0, data.queue?.running_count || 0, data.queue?.queued_count || 0);
1117
+ // Poll until the task completes, then download
1118
+ return pollTaskUntilReady(data.task_id).then(() => ({ url: null }));
1119
+ }
1120
+ // Backward compatible with the old direct-file response format
1121
+ if (data && data.url) {
1122
+ updateStatus('Loading generated scene...', cameraParams.length);
1123
+ return Promise.resolve(data);
1124
+ }
1125
+ throw new Error('Invalid Flask response (expected task_id)');
1126
1127
+ })
1128
+ .then(data => {
1129
+ if (data.url) {
1130
+ updateStatus('Loading 3D scene...', cameraParams.length);
1131
+ // Remove the instruction splat when generation is complete
1132
+ if (instructionSplat) {
1133
+ scene.remove(instructionSplat);
1134
+ console.log('Instruction splat removed');
1135
+ }
1136
+ const GeneratedSplat = new SplatMesh({ url: data.url });
1137
+ scene.add(GeneratedSplat);
1138
+ currentGeneratedSplat = GeneratedSplat; // keep a reference to the newly generated scene
1139
+ console.log('3D scene loaded successfully!');
1140
+ updateStatus('Scene generated successfully!', cameraParams.length);
1141
+ hideDownloadProgress();
1142
+ showLoading(false);
1143
+ }
1144
+ })
1145
+ .catch(error => {
1146
+ console.error('Error:', error);
1147
+ updateStatus('Generation failed: ' + error.message, cameraParams.length);
1148
+ hideDownloadProgress();
1149
+ showLoading(false);
1150
+ });
1151
+ }
1152
+ };
1153
+
1154
+ // Initialize renderer and GUI when DOM is ready
1155
+ function initializeApp() {
1156
+ try {
1157
+ // Debug layout
1158
+ console.log('Initializing app...');
1159
+ console.log('Center panel:', document.querySelector('.center-panel'));
1160
+ console.log('GUI container:', document.getElementById('gui-container'));
1161
+ console.log('Right panel:', document.querySelector('.right-panel'));
1162
+
1163
+ initializeRenderer();
1164
+ initializeGUI();
1165
+ console.log('App initialization complete');
1166
+ } catch (error) {
1167
+ console.error('App initialization failed:', error);
1168
+ }
1169
+ }
1170
+
1171
+ if (document.readyState === 'loading') {
1172
+ document.addEventListener('DOMContentLoaded', initializeApp);
1173
+ } else {
1174
+ initializeApp();
1175
+ }
1176
+
1177
+ // =========================
1178
+ // Utility & Core Functions
1179
+ // =========================
1180
+
1181
+ // Interpolate between two cameras
1182
+ function interpolateTwoCameras(startCamera, endCamera, _t) {
1183
+ const interpolatedCamera = new THREE.PerspectiveCamera(startCamera.fov, startCamera.aspect);
1184
+
1185
+ // If _t is near 0, use startCamera directly
1186
+ if (_t < 1e-6) {
1187
+ interpolatedCamera.position.copy(startCamera.position);
1188
+ interpolatedCamera.quaternion.copy(startCamera.quaternion);
1189
+ }
1190
+ // If _t is near 1, use endCamera directly
1191
+ else if (_t > 1 - 1e-6) {
1192
+ interpolatedCamera.position.copy(endCamera.position);
1193
+ interpolatedCamera.quaternion.copy(endCamera.quaternion);
1194
+ }
1195
+ // Otherwise interpolate
1196
+ else {
1197
+ interpolatedCamera.position.copy(startCamera.position).lerp(endCamera.position, _t);
1198
+ interpolatedCamera.quaternion.copy(startCamera.quaternion).slerp(endCamera.quaternion, _t);
1199
+ }
1200
+
1201
+ return interpolatedCamera;
1202
+ }
1203
+
1204
+ function interpolateCameras(cameras, M) {
1205
+ const interpolatedCameras = [];
1206
+
1207
+ if (cameras.length === 0) {
1208
+ return interpolatedCameras;
1209
+ }
1210
+
1211
+ if (cameras.length === 1) {
1212
+ // With a single camera, just repeat it
1213
+ for (let i = 0; i < M; i++) {
1214
+ interpolatedCameras.push(cameras[0]);
1215
+ }
1216
+ return interpolatedCameras;
1217
+ }
1218
+
1219
+ for (let i = 0; i < M; i++) {
1220
+ const t = i / (M - 1);
1221
+ const startIndex = Math.min(Math.floor(t * (cameras.length - 1)), cameras.length - 2);
1222
+ const endIndex = startIndex + 1;
1223
+ const startCamera = cameras[startIndex];
1224
+ const endCamera = cameras[endIndex];
1225
+ const _t = t * (cameras.length - 1) - startIndex;
1226
+ const interpolatedCamera = interpolateTwoCameras(startCamera, endCamera, _t);
1227
+ interpolatedCameras.push(interpolatedCamera);
1228
+ }
1229
+ return interpolatedCameras;
1230
+ }
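+ // e.g. interpolateCameras(cameraParams, 24) resamples the recorded keyframes into
+ // 24 evenly spaced cameras, lerping positions and slerping rotations between neighbors.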
1231
+
1232
+ // Create a wireframe-cube splat visualization
1233
+ function createCubeSplat(size = 0.1, pointColor = [1, 1, 1]) {
1234
+ const cubeSplat = new SplatMesh({
1235
+ constructSplats: (splats) => {
1236
+ const NUM_SPLATS_PER_EDGE = 1000;
1237
+ const scales = new THREE.Vector3().setScalar(0.002);
1238
+ const quaternion = new THREE.Quaternion();
1239
+ const opacity = 1;
1240
+ const color = new THREE.Color(...pointColor);
1241
+
1242
+ // The cube's 8 vertices
1243
+ const halfSize = size / 2;
1244
+ const vertices = [
1245
+ new THREE.Vector3(-halfSize, -halfSize, -halfSize), // 0: left-bottom-back
+ new THREE.Vector3(halfSize, -halfSize, -halfSize), // 1: right-bottom-back
+ new THREE.Vector3(halfSize, halfSize, -halfSize), // 2: right-top-back
+ new THREE.Vector3(-halfSize, halfSize, -halfSize), // 3: left-top-back
+ new THREE.Vector3(-halfSize, -halfSize, halfSize), // 4: left-bottom-front
+ new THREE.Vector3(halfSize, -halfSize, halfSize), // 5: right-bottom-front
+ new THREE.Vector3(halfSize, halfSize, halfSize), // 6: right-top-front
+ new THREE.Vector3(-halfSize, halfSize, halfSize), // 7: left-top-front
1253
+ ];
1254
+
1255
+ // The cube's 12 edges
1256
+ const edges = [
1257
+ [0, 1], [1, 2], [2, 3], [3, 0], // 4 back edges
+ [4, 5], [5, 6], [6, 7], [7, 4], // 4 front edges
+ [0, 4], [1, 5], [2, 6], [3, 7], // 4 edges connecting front and back
1260
+ ];
1261
+
1262
+ // Generate splat points along each edge
1263
+ for (let i = 0; i < edges.length; i++) {
1264
+ const start = vertices[edges[i][0]];
1265
+ const end = vertices[edges[i][1]];
1266
+ for (let j = 0; j < NUM_SPLATS_PER_EDGE; j++) {
1267
+ const point = new THREE.Vector3().lerpVectors(start, end, j / NUM_SPLATS_PER_EDGE);
1268
+ splats.pushSplat(point, scales, quaternion, opacity, color);
1269
+ }
1270
+ }
1271
+ },
1272
+ });
1273
+ return cubeSplat;
1274
+ }
1275
+
1276
+ // Create a camera-frustum splat visualization
1277
+ function createCameraSplat(camera, pointColor = [1, 1, 1]) {
1278
+ const cameraSplat = new SplatMesh({
1279
+ constructSplats: (splats) => {
1280
+ const NUM_SPLATS_PER_EDGE = 1000;
1281
+ const LENGTH_PER_EDGE = 0.1;
1282
+ const center = new THREE.Vector3();
1283
+ const scales = new THREE.Vector3().setScalar(0.001);
1284
+ const quaternion = new THREE.Quaternion();
1285
+ const opacity = 1;
1286
+ const color = new THREE.Color(...pointColor);
1287
+
1288
+ const H = 1000;
1289
+ const W = 1000 * camera.aspect;
1290
+ const fx = 0.5 * H / Math.tan(0.5 * camera.fov * Math.PI / 180);
1291
+ const fy = 0.5 * H / Math.tan(0.5 * camera.fov * Math.PI / 180);
1292
+
1293
+ const xt = (0 - W / 2 + 0.5) / fy;
1294
+ const xb = (W - W / 2 + 0.5) / fy;
1295
+ const yl = - (0 - H / 2 + 0.5) / fx;
1296
+ const yr = - (H - H / 2 + 0.5) / fx;
1297
+
1298
+ const lt = new THREE.Vector3(xt * LENGTH_PER_EDGE, yl * LENGTH_PER_EDGE, -1 * LENGTH_PER_EDGE);
1299
+ const rt = new THREE.Vector3(xt * LENGTH_PER_EDGE, yr * LENGTH_PER_EDGE, -1 * LENGTH_PER_EDGE);
1300
+ const lb = new THREE.Vector3(xb * LENGTH_PER_EDGE, yl * LENGTH_PER_EDGE, -1 * LENGTH_PER_EDGE);
1301
+ const rb = new THREE.Vector3(xb * LENGTH_PER_EDGE, yr * LENGTH_PER_EDGE, -1 * LENGTH_PER_EDGE);
1302
+
1303
+ const lines = [
1304
+ [center, lt], [center, rt], [center, lb], [center, rb],
1305
+ [lt, rt], [lt, lb], [rt, rb], [lb, rb],
1306
+ ];
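+ // Four rays from the camera center to the image-plane corners, plus the four
+ // edges of the image-plane rectangle, each drawn as a string of tiny splats.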
1307
+
1308
+ for (let i = 0; i < lines.length; i++) {
1309
+ for (let j = 0; j < NUM_SPLATS_PER_EDGE; j++) {
1310
+ const point = new THREE.Vector3().lerpVectors(lines[i][0], lines[i][1], j / NUM_SPLATS_PER_EDGE);
1311
+ splats.pushSplat(point, scales, quaternion, opacity, color);
1312
+ }
1313
+ }
1314
+ },
1315
+ });
1316
+ cameraSplat.quaternion.copy(camera.quaternion);
1317
+ cameraSplat.position.copy(camera.position);
1318
+ return cameraSplat;
1319
+ }
1320
+
1321
+ // Generate a camera trajectory from a template
1322
+ function generateCameraTrajectory(trajectoryType) {
1323
+ if (trajectoryType === "Manual") {
1324
+ updateStatus('Manual mode: Use Space to record cameras manually', cameraParams.length);
1325
+ return;
1326
+ }
1327
+
1328
+ // Check that the FOV has been fixed
1329
+ if (!fixGenerationFOV) {
1330
+ updateStatus('Error: Please fix FOV first before generating trajectory', cameraParams.length);
1331
+ return;
1332
+ }
1333
+
1334
+ // Use the last camera as the reference point
1335
+ let referenceCamera;
1336
+ if (cameraParams.length > 0) {
1337
+ // Use the last recorded camera as the reference
1338
+ const lastCamera = cameraParams[cameraParams.length - 1];
1339
+ referenceCamera = new THREE.PerspectiveCamera(guiOptions.FOV, camera.aspect);
1340
+ referenceCamera.position.copy(lastCamera.position);
1341
+ referenceCamera.quaternion.copy(lastCamera.quaternion);
1342
+ referenceCamera.updateProjectionMatrix();
1343
+ } else {
1344
+ // No recorded cameras: start from the origin
1345
+ referenceCamera = new THREE.PerspectiveCamera(guiOptions.FOV, camera.aspect);
1346
+ referenceCamera.position.set(0, 0, 0);
1347
+ referenceCamera.quaternion.set(0, 0, 0, 1);
1348
+ referenceCamera.updateProjectionMatrix();
1349
+ }
1350
+
1351
+ // For orbit trajectories, compute the target point the cameras orbit around.
+ // Always derive the target from the current reference camera (the last one).
1353
+ let orbitTarget = null;
1354
+ let orbitStartCamera = null;
1355
+ if (trajectoryType.includes("Orbit") && cameraParams.length > 0) {
1356
+ // Use the last camera as reference; the target is 1 unit in front of it
1357
+ orbitStartCamera = cameraParams[cameraParams.length - 1];
1358
+ orbitTarget = orbitStartCamera.position.clone().add(
1359
+ new THREE.Vector3(0, 0, -1).applyQuaternion(orbitStartCamera.quaternion)
1360
+ );
1361
+ console.log("Orbit target calculated from last camera:", orbitStartCamera.position, "->", orbitTarget);
1362
+ } else if (trajectoryType.includes("Orbit")) {
1363
+ // No recorded cameras: use the current camera as the reference
1364
+ orbitStartCamera = referenceCamera;
1365
+ orbitTarget = referenceCamera.position.clone().add(
1366
+ new THREE.Vector3(0, 0, -1).applyQuaternion(referenceCamera.quaternion)
1367
+ );
1368
+ console.log("Orbit target calculated from current camera:", referenceCamera.position, "->", orbitTarget);
1369
+ }
1370
+
1371
+ const cameras = [];
1372
+ const stepSize = 0.5; // translation step size
1373
+ const totalOrbitAngle = 15 * Math.PI / 180; // total orbit angle of 15 degrees
1374
+
1375
+ // Generate cameras according to the trajectory type
1376
+ let numCameras = 1; // generate 1 camera by default
1377
+ if (trajectoryType.includes("Orbit")) {
1378
+ numCameras = 1; // orbit motion also generates 1 camera
1379
+ console.log(`Generating ${numCameras} orbit camera with total angle ${totalOrbitAngle * 180 / Math.PI}°`);
1380
+ }
1381
+
1382
+ for (let i = 1; i <= numCameras; i++) {
1383
+ const newCamera = new THREE.PerspectiveCamera(guiOptions.FOV, camera.aspect);
1384
+ let position, quaternion;
1385
+
1386
+ switch (trajectoryType) {
1387
+ case "Move Forward":
1388
+ position = referenceCamera.position.clone();
1389
+ position.z -= stepSize;
1390
+ quaternion = referenceCamera.quaternion.clone();
1391
+ break;
1392
+
1393
+ case "Move Backward":
1394
+ position = referenceCamera.position.clone();
1395
+ position.z += stepSize;
1396
+ quaternion = referenceCamera.quaternion.clone();
1397
+ break;
1398
+
1399
+ case "Move Left":
1400
+ position = referenceCamera.position.clone();
1401
+ position.x -= stepSize;
1402
+ quaternion = referenceCamera.quaternion.clone();
1403
+ break;
1404
+
1405
+ case "Move Right":
1406
+ position = referenceCamera.position.clone();
1407
+ position.x += stepSize;
1408
+ quaternion = referenceCamera.quaternion.clone();
1409
+ break;
1410
+
1411
+ case "Orbit Left 15°":
1412
+ const radius = 1.0;
1413
+ // Left orbit: -15 degrees
1414
+ const angle = -totalOrbitAngle;
1415
+
1416
+ console.log(`Camera ${i}: angle=${angle * 180 / Math.PI}° (Left)`);
1417
+
1418
+ // Orbit position in the reference camera's local frame
1419
+ const localOrbitPos = new THREE.Vector3(
1420
+ Math.sin(angle) * radius,
1421
+ 0,
1422
+ Math.cos(angle) * radius
1423
+ );
1424
+
1425
+ // Transform to world space: rotate into the reference camera's orientation
1426
+ const worldOrbitPos = localOrbitPos.applyQuaternion(orbitStartCamera.quaternion);
1427
+
1428
+ // Final position: the target point plus the world-space offset
1429
+ position = orbitTarget.clone().add(worldOrbitPos);
1430
+
1431
+ console.log(`Orbit Left camera ${i}: localPos=`, localOrbitPos, 'worldPos=', worldOrbitPos, 'finalPos=', position);
1432
+
1433
+ // Orientation: every camera looks at the orbit center (the target)
1434
+ const lookDirection = orbitTarget.clone().sub(position).normalize();
1435
+ quaternion = new THREE.Quaternion().setFromUnitVectors(
1436
+ new THREE.Vector3(0, 0, -1),
1437
+ lookDirection
1438
+ );
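+ // setFromUnitVectors yields the rotation taking the camera's default forward
+ // (0, 0, -1) onto lookDirection; roll about the view axis is left unconstrained.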
1439
+
1440
+ console.log(`Orbit Left camera ${i}: quaternion=`, quaternion);
1441
+ break;
1442
+
1443
+ case "Orbit Right 15°":
1444
+ const radiusRight = 1.0;
1445
+ // Right orbit: +15 degrees
1446
+ const angleRight = totalOrbitAngle;
1447
+
1448
+ console.log(`Camera ${i}: angle=${angleRight * 180 / Math.PI}° (Right)`);
1449
+
1450
+ // Orbit position in the reference camera's local frame
1451
+ const localOrbitPosRight = new THREE.Vector3(
1452
+ Math.sin(angleRight) * radiusRight,
1453
+ 0,
1454
+ Math.cos(angleRight) * radiusRight
1455
+ );
1456
+
1457
+ // Transform to world space: rotate into the reference camera's orientation
1458
+ const worldOrbitPosRight = localOrbitPosRight.applyQuaternion(orbitStartCamera.quaternion);
1459
+
1460
+ // Final position: the target point plus the world-space offset
1461
+ position = orbitTarget.clone().add(worldOrbitPosRight);
1462
+
1463
+ console.log(`Orbit Right camera ${i}: localPos=`, localOrbitPosRight, 'worldPos=', worldOrbitPosRight, 'finalPos=', position);
1464
+
1465
+ // Orientation: every camera looks at the orbit center (the target)
1466
+ const lookDirectionRight = orbitTarget.clone().sub(position).normalize();
1467
+ quaternion = new THREE.Quaternion().setFromUnitVectors(
1468
+ new THREE.Vector3(0, 0, -1),
1469
+ lookDirectionRight
1470
+ );
1471
+
1472
+ console.log(`Orbit Right camera ${i}: quaternion=`, quaternion);
1473
+ break;
1474
+
1475
+
1476
+ default:
1477
+ position = referenceCamera.position.clone();
1478
+ quaternion = referenceCamera.quaternion.clone();
1479
+ }
1480
+
1481
+ newCamera.position.copy(position);
1482
+ newCamera.quaternion.copy(quaternion);
1483
+ newCamera.updateProjectionMatrix();
1484
+ cameras.push(newCamera);
1485
+ }
1486
+
1487
+ // Add the cameras to the scene
1488
+ cameras.forEach(cam => {
1489
+ const cameraSplat = createCameraSplat(cam);
1490
+ cameraSplats.push(cameraSplat);
1491
+ cameraParams.push({
1492
+ position: cam.position.clone(),
1493
+ quaternion: cam.quaternion.clone(),
1494
+ fov: cam.fov,
1495
+ aspect: cam.aspect,
1496
+ });
1497
+ scene.add(cameraSplat);
1498
+ });
1499
+
1500
+ updateStatus(`Added ${cameras.length} cameras using ${trajectoryType} trajectory`, cameraParams.length);
1501
+ console.log(`Added ${cameras.length} cameras using ${trajectoryType} trajectory`);
1502
+ }
1503
+
1504
+ // =========================
1505
+ // GUI & User Interaction
1506
+ // =========================
1507
+
1508
+ // GUI controls - initialized lazily
1509
+ function initializeGUI() {
1510
+ const guiContainer = document.getElementById('gui-container');
1511
+ if (guiContainer && !gui) {
1512
+ // Clear any existing content
1513
+ guiContainer.innerHTML = '';
1514
+
1515
+ gui = new GUI({ title: "FlashWorld Controls", container: guiContainer });
1516
+ console.log('GUI initialized in container:', guiContainer);
1517
+
1518
+ // Step 1: Configure Generation Settings
1519
+ const step1Folder = gui.addFolder('1. Configure Settings');
1520
+ step1Folder.add(guiOptions, "BackendAddress").name("Backend Address");
1521
+
1522
+ // FOV and Resolution controllers, enabled initially
1523
+ const fovController = step1Folder.add(guiOptions, "FOV", 0, 120, 1).name("FOV").onChange((value) => {
1524
+ camera.fov = value;
1525
+ camera.updateProjectionMatrix();
1526
+ });
1527
+ const resolutionController = step1Folder.add(guiOptions, "Resolution", supportedResolutions.map(
1528
+ r => `${r.frame}x${r.height}x${r.width}`
1529
+ )).name("Resolution (NxHxW)").onChange((value) => {
1530
+ updateCanvasSize();
1531
+ });
1532
+
1533
+ // The Fix Configuration button goes at the bottom
1534
+ const fixGenerationFOVController = step1Folder.add(guiOptions, "fixGenerationFOV").name("Fix Configuration");
1535
+ step1Folder.open();
1536
+
1537
+ // Step 2: Set Up Camera Path
1538
+ const step2Folder = gui.addFolder('2. Set Up Camera Path');
1539
+
1540
+ // Camera trajectory templates
1541
+ const trajectoryFolder = step2Folder.addFolder('Camera Trajectory');
1542
+
1543
+ // Trajectory mode selection
1544
+ const trajectoryModeController = trajectoryFolder.add(guiOptions, "trajectoryMode", [
1545
+ "Manual",
1546
+ "Template",
1547
+ "JSON"
1548
+ ]).name("Trajectory Mode");
1549
+
1550
+ // Template type selection (only usable in Template mode)
1551
+ const templateTypeController = trajectoryFolder.add(guiOptions, "templateType", [
1552
+ "Move Forward",
1553
+ "Move Backward",
1554
+ "Move Left",
1555
+ "Move Right",
1556
+ "Orbit Left 15°",
1557
+ "Orbit Right 15°"
1558
+ ]).name("Template Type");
1559
+
1560
+ // Generate-trajectory button
1561
+ const generateTrajectoryController = trajectoryFolder.add(guiOptions, "generateTrajectory").name("Generate Trajectory");
1562
+
1563
+ // Load/save JSON trajectory buttons
1564
+ const loadTrajectoryController = trajectoryFolder.add(guiOptions, "LoadTrajectoryFromJson").name("Load from JSON");
1565
+ const saveTrajectoryController = trajectoryFolder.add(guiOptions, "saveTrajectoryToJson").name("Save Trajectory");
1566
+
1567
+ // Clear-cameras button
1568
+ const clearAllCamerasController = trajectoryFolder.add(guiOptions, "clearAllCameras").name("Clear All Cameras");
1569
+
1570
+ // Initial state: disable all trajectory-related controls
1571
+ templateTypeController.disable();
1572
+ generateTrajectoryController.disable();
1573
+ loadTrajectoryController.disable();
1574
+
1575
+ // Handle trajectory mode changes
1576
+ trajectoryModeController.onChange((value) => {
1577
+ if (value === "Manual") {
1578
+ templateTypeController.disable();
1579
+ generateTrajectoryController.disable();
1580
+ loadTrajectoryController.disable();
1581
+ } else if (value === "Template") {
1582
+ templateTypeController.enable();
1583
+ if (fixGenerationFOV) {
1584
+ generateTrajectoryController.enable();
1585
+ } else {
1586
+ generateTrajectoryController.disable();
1587
+ }
1588
+ loadTrajectoryController.disable();
1589
+ } else if (value === "JSON") {
1590
+ templateTypeController.disable();
1591
+ generateTrajectoryController.disable();
1592
+ if (fixGenerationFOV) {
1593
+ loadTrajectoryController.enable();
1594
+ } else {
1595
+ loadTrajectoryController.disable();
1596
+ }
1597
+ }
1598
+ });
1599
+
1600
+ // Enable trajectory generation once the configuration is fixed
1601
+ const originalFixFOV = guiOptions.fixGenerationFOV;
1602
+ guiOptions.fixGenerationFOV = () => {
1603
+ originalFixFOV();
1604
+
1605
+ // After Fix Configuration, disable all Step 1 controllers
1606
+ fovController.disable();
1607
+ resolutionController.disable();
1608
+
1609
+ // Enable the controls matching the current trajectory mode
1610
+ if (guiOptions.trajectoryMode === "Template") {
1611
+ generateTrajectoryController.enable();
1612
+ } else if (guiOptions.trajectoryMode === "JSON") {
1613
+ loadTrajectoryController.enable();
1614
+ }
1615
+ updateStatus('Configuration fixed. You can now generate camera trajectory.', cameraParams.length);
1616
+ };
1617
+
1618
+ trajectoryFolder.open();
1619
+
1620
+ step2Folder.add(guiOptions, "VisualizeCameraSplats").name("Visualize Cameras").onChange((value) => {
1621
+ cameraSplats.forEach(cameraSplat => {
1622
+ cameraSplat.opacity = value ? 1 : 0;
1623
+ });
1624
+ });
1625
+ step2Folder.add(guiOptions, "VisualizeInterpolatedCameras").name("Visualize Interpolated Cameras").onChange((value) => {
1626
+ interpolatedCamerasSplats.forEach(interpolatedCameraSplat => {
1627
+ interpolatedCameraSplat.opacity = value ? 1 : 0;
1628
+ });
1629
+ });
1630
+
1631
+ // Store controllers globally so they can be accessed from guiOptions
1632
+ window.fixGenerationFOVController = fixGenerationFOVController;
1633
+
1634
+ // Step 3: Add Scene Prompts
1635
+ const step3Folder = gui.addFolder('3. Add Scene Prompts');
1636
+ step3Folder.add(guiOptions, "inputImagePrompt").name("Input Image Prompt");
1637
+ step3Folder.add(guiOptions, "inputTextPrompt").name("Input Text Prompt");
1638
+ step3Folder.add(guiOptions, "imageIndex", 0, 24, 1).name("Image Index");
1639
+
1640
+
1641
+ // Step 4: Generate Your Scene
1642
+ const step4Folder = gui.addFolder('4. Generate Scene');
1643
+ step4Folder.add(guiOptions, "generate").name("Generate!");
1644
+ step4Folder.open();
1645
+
1646
+ // Step 5: Trajectory Playback (Scrubber)
1647
+ const step5Folder = gui.addFolder('5. Trajectory Playback');
1648
+ step5Folder.add(guiOptions, 'playbackT', 0, 1, 0.001).name('Scrub (0-1)').onChange((value) => {
1649
+ // On the first scrub, record the user's camera state so it can optionally be restored
1650
+ if (!userCameraState) {
1651
+ userCameraState = {
1652
+ position: camera.position.clone(),
1653
+ quaternion: camera.quaternion.clone(),
1654
+ fov: camera.fov
1655
+ };
1656
+ }
1657
+ setCameraByScrub(value);
1658
+ updateStatus(`Scrubbing trajectory: t=${value.toFixed(3)}`, cameraParams.length);
1659
+ });
1660
+ step5Folder.open();
1661
+
1662
+ }
1663
+ }
1664
+
1665
+
1666
+ // =========================
1667
+ // File Input (Image Prompt)
1668
+ // =========================
1669
+ const fileInput = document.querySelector("#file-input");
1670
+ fileInput.onchange = (event) => {
1671
+ const files = event.target.files;
1672
+ if (!files || files.length === 0) return;
1673
+ Array.from(files).forEach(file => {
1674
+ const reader = new FileReader();
1675
+ reader.onload = function(e) {
1676
+ console.log("Loaded image:", file.name, e.target.result);
1677
+
1678
+ // Get the current Resolution
1679
+ let resolutionStr = guiOptions.Resolution;
1680
+ let [n, h, w] = resolutionStr.split('x').map(Number);
1681
+
1682
+ // Load the image
1683
+ const img = new Image();
1684
+ img.onload = function() {
1685
+ window.inputImageResolution = { width: img.width, height: img.height };
1686
+ console.log("Input image resolution:", window.inputImageResolution);
1687
+
1688
+ // Compute the center-crop parameters
1689
+ let scaleH = h / img.height;
1690
+ let scaleW = w / img.width;
1691
+ let scale = Math.max(scaleH, scaleW);
1692
+ let newW = Math.round(w / scale);
1693
+ let newH = Math.round(h / scale);
1694
+ let sx = Math.floor((img.width - newW) / 2);
1695
+ let sy = Math.floor((img.height - newH) / 2);
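+ // Worked example: a 1920x1080 image targeting w=704, h=480 gives
+ // scale = max(480/1080, 704/1920) ≈ 0.444, so a centered 1584x1080 region
+ // (sx = 168, sy = 0) is cropped, then scaled down to 704x480.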
1696
+
1697
+ // Create a canvas for the center crop and resize
1698
+ const canvas = document.createElement('canvas');
1699
+ canvas.width = w;
1700
+ canvas.height = h;
1701
+ const ctx = canvas.getContext('2d');
1702
+ ctx.drawImage(
1703
+ img,
1704
+ sx, sy, newW, newH, // source crop
1705
+ 0, 0, w, h // destination size
1706
+ );
1707
+ // Base64 of the cropped and resized image (sent to the backend)
1708
+ inputImageBase64 = canvas.toDataURL('image/png');
1709
+ // Update the preview with the cropped image
1710
+ const previewArea = document.getElementById('image-preview-area');
1711
+ const previewImg = document.getElementById('preview-img');
1712
+ if (previewImg && previewArea) {
1713
+ previewImg.src = inputImageBase64;
1714
+ previewArea.style.display = 'block';
1715
+ }
1716
+ // Record the resolution sent to the backend (now aligned to the current Resolution)
1717
+ window.inputImageResolution = { width: w, height: h };
1718
+ console.log("Cropped and resized image to:", w, h);
1719
+ };
1720
+ img.src = e.target.result;
1721
+ };
1722
+ reader.readAsDataURL(file);
1723
+ });
1724
+
1725
+ };
1726
+
1727
+ // =========================
1728
+ // File Input (JSON)
1729
+ // =========================
1730
1846
+
1847
+ const jsonInput = document.querySelector("#json-input");
1848
+ jsonInput.onchange = (event) => {
1849
+ const files = event.target.files;
1850
+ if (!files || files.length === 0) return;
1851
+ const file = files[0];
1852
+ const reader = new FileReader();
1853
+ reader.onload = function(e) {
1854
+ let jsonData;
1855
+ try {
1856
+ jsonData = JSON.parse(e.target.result);
1857
+ } catch (error) {
1858
+ console.error("JSON parsing error:", error);
1859
+ return;
1860
+ }
1861
+
1862
+ // Check whether we are loading the trajectory only
1863
+ const loadTrajectoryOnly = window.loadTrajectoryOnly;
1864
+ window.loadTrajectoryOnly = false; // reset the flag
1865
+
1866
+ // Both trajectory-only and full-JSON loads start from a clean slate:
+ // remove all existing cameras and interpolated cameras
+ cameraSplats.forEach(splat => scene.remove(splat));
+ cameraSplats.length = 0;
+ cameraParams.length = 0;
+ interpolatedCamerasSplats.forEach(splat => scene.remove(splat));
+ interpolatedCamerasSplats.length = 0;
1881
+
1882
+ try {
1883
+ // 兼容不同命名的字段
1884
+ const imagePrompt = jsonData.image_prompt || jsonData.imagePrompt || null;
1885
+ const textPrompt = jsonData.text_prompt || jsonData.textPrompt || "";
1886
+ const cameras = jsonData.cameras || [];
1887
+ const resolution = jsonData.resolution || [16, 480, 640];
1888
+ const imageIndex = jsonData.image_index || jsonData.imageIndex || 0;
1889
+
1890
+ console.log("Loaded JSON data:", {
1891
+ imagePrompt,
1892
+ textPrompt,
1893
+ cameras: cameras.length,
1894
+ resolution,
1895
+ imageIndex
1896
+ });
1897
+
1898
+ // Handle the image prompt (only when not loading trajectory only)
1899
+ if (!loadTrajectoryOnly && imagePrompt) {
1900
+ inputImageBase64 = imagePrompt;
1901
+ console.log("Image prompt loaded");
1902
+ }
1903
+
1904
+ // Set the text prompt (only when not loading trajectory only)
1905
+ if (!loadTrajectoryOnly) {
1906
+ guiOptions.inputTextPrompt = textPrompt;
1907
+ guiOptions.imageIndex = imageIndex;
1908
+ }
1909
+
1910
+ // Process the camera data
1911
+ if (cameras && cameras.length > 0) {
1912
+ let jsonFirstCamera = null;
1913
+ let jsonFirstPosition = null;
1914
+ let jsonFirstQuaternion = null;
1915
+
1916
+ // First read the position and quaternion of the first camera in the JSON
1917
+ if (loadTrajectoryOnly && cameras.length > 0) {
1918
+ const firstCameraData = cameras[0];
1919
+ if (Array.isArray(firstCameraData.position) && firstCameraData.position.length === 3) {
1920
+ jsonFirstPosition = new THREE.Vector3(
1921
+ firstCameraData.position[0],
1922
+ firstCameraData.position[1],
1923
+ firstCameraData.position[2]
1924
+ );
1925
+ }
1926
+ if (Array.isArray(firstCameraData.quaternion) && firstCameraData.quaternion.length === 4) {
1927
+ jsonFirstQuaternion = new THREE.Quaternion(
1928
+ firstCameraData.quaternion[1],
1929
+ firstCameraData.quaternion[2],
1930
+ firstCameraData.quaternion[3],
1931
+ firstCameraData.quaternion[0]
1932
+ );
1933
+ }
1934
+ }
1935
+
1936
+ cameras.forEach((cameraData, index) => {
1937
+ // 解析分辨率
1938
+ let aspect = 1.0;
1939
+ if (Array.isArray(resolution) && resolution.length === 3) {
1940
+ aspect = resolution[2] / resolution[1];
1941
+ } else {
1942
+ aspect = guiOptions.Resolution.split('x')[2] / guiOptions.Resolution.split('x')[1];
1943
+ }
1944
+
1945
+ // Decide the FOV based on the load mode
1946
+ let fov = 60;
1947
+ if (loadTrajectoryOnly) {
1948
+ // Trajectory load: use the FOV set in the GUI
1949
+ fov = guiOptions.FOV;
1950
+ } else {
1951
+ // Full JSON load: use the FOV from the JSON or the default
1952
+ if (cameraData.fx && cameraData.fy) {
1953
+ fov = 2 * Math.atan(0.5 / cameraData.fx) * 180 / Math.PI;
1954
+ }
1955
+ }
1956
+
1957
+ const cam = new THREE.PerspectiveCamera(fov, aspect);
1958
+
1959
+ // Set position and quaternion
1960
+ if (Array.isArray(cameraData.position) && cameraData.position.length === 3) {
1961
+ cam.position.set(cameraData.position[0], cameraData.position[1], cameraData.position[2]);
1962
+ }
1963
+
1964
+ if (Array.isArray(cameraData.quaternion) && cameraData.quaternion.length === 4) {
1965
+ // Note: the JSON stores [w, x, y, z]; three.js expects (x, y, z, w)
1966
+ cam.quaternion.set(
1967
+ cameraData.quaternion[1],
1968
+ cameraData.quaternion[2],
1969
+ cameraData.quaternion[3],
1970
+ cameraData.quaternion[0]
1971
+ );
1972
+ }
1973
+
1974
+ // Trajectory load: force the first camera to the origin (kept disabled; the normalization below achieves this)
1975
+ // if (loadTrajectoryOnly && index === 0) {
1976
+ // cam.position.set(0, 0, 0);
1977
+ // cam.quaternion.set(0, 0, 0, 1);
1978
+ // }
1979
+
1980
+ // Trajectory load: normalize poses relative to the fixed-FOV reference camera
1981
+ if (loadTrajectoryOnly && jsonFirstPosition && jsonFirstQuaternion) {
1982
+ // Normalization logic mirrors the Python code
1983
+ // 1. Compute the c2w matrix of the JSON's first camera
1984
+ const jsonFirstC2W = new THREE.Matrix4();
1985
+ jsonFirstC2W.compose(jsonFirstPosition, jsonFirstQuaternion, new THREE.Vector3(1, 1, 1));
1986
+
1987
+ // 2. Compute the current camera's c2w matrix
1988
+ const currentC2W = new THREE.Matrix4();
1989
+ currentC2W.compose(cam.position, cam.quaternion, new THREE.Vector3(1, 1, 1));
1990
+
1991
+ // 3. Compute the relative transform: ref_w2c @ current_c2w
1992
+ const refW2C = jsonFirstC2W.clone().invert();
1993
+ const relativeTransform = refW2C.clone().multiply(currentC2W);
1994
+
1995
+ // 4. Apply the relative transform to an origin camera (as the reference)
1996
+ const fixedC2W = new THREE.Matrix4();
1997
+ fixedC2W.compose(new THREE.Vector3(0, 0, 0), new THREE.Quaternion(0, 0, 0, 1), new THREE.Vector3(1, 1, 1));
1998
+
1999
+ const newTransform = fixedC2W.clone().multiply(relativeTransform);
2000
+
2001
+ // 5. Extract the new position and rotation
2002
+ const newPosition = new THREE.Vector3();
2003
+ const newQuaternion = new THREE.Quaternion();
2004
+ const newScale = new THREE.Vector3();
2005
+ newTransform.decompose(newPosition, newQuaternion, newScale);
2006
+
2007
+ cam.position.copy(newPosition);
2008
+ cam.quaternion.copy(newQuaternion);
2009
+ }
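+ // Net effect: the loaded trajectory is re-expressed relative to its own first
+ // camera, so playback starts at the origin looking down -Z.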
2010
+
2011
+ // Set FOV and focal length (only when not loading trajectory only)
2012
+ if (!loadTrajectoryOnly && cameraData.fx && cameraData.fy) {
2013
+ cam.fov = fov;
2014
+ cam.aspect = cameraData.fx / cameraData.fy;
2015
+ cam.updateProjectionMatrix();
2016
+ } else if (loadTrajectoryOnly) {
2017
+ // Trajectory load: use the GUI's FOV and aspect
2018
+ cam.fov = fov;
2019
+ cam.aspect = aspect;
2020
+ cam.updateProjectionMatrix();
2021
+ }
2022
+
2023
+ const cameraSplat = createCameraSplat(cam);
2024
+ cameraSplats.push(cameraSplat);
2025
+ cameraParams.push({
2026
+ position: cam.position.clone(),
2027
+ quaternion: cam.quaternion.clone(),
2028
+ fov: cam.fov,
2029
+ aspect: cam.aspect,
2030
+ });
2031
+ scene.add(cameraSplat);
2032
+ });
2033
+
2034
+ console.log(cameraParams);
2035
+ }
2036
+
2037
+ // Set the resolution (only when not loading trajectory only)
2038
+ if (!loadTrajectoryOnly && Array.isArray(resolution) && resolution.length === 3) {
2039
+ guiOptions.Resolution = `${resolution[0]}x${resolution[1]}x${resolution[2]}`;
2040
+ }
2041
+
2042
+ // Show a success message
+ if (loadTrajectoryOnly) {
+ updateStatus(`Trajectory loaded: ${cameras.length} cameras`, cameraParams.length);
+ }
2047
+ } catch (error) {
2048
+ console.error("JSON data processing error:", error);
2049
+ }
2050
+ };
2051
+ reader.readAsText(file);
2052
+ };
2053
+
2054
+ // =========================
2055
+ // Keyboard Controls
2056
+ // =========================
2057
+ document.addEventListener('keypress', (event) => {
2058
+ if (event.code === 'Space') {
2059
+ if (!fixGenerationFOV) {
2060
+ updateStatus('Please fix Generation FOV first', cameraParams.length);
2061
+ return;
2062
+ }
2063
+ // Record the current camera pose
2064
+ const new_camera = camera.clone();
2065
+ new_camera.fov = guiOptions.FOV;
2066
+ new_camera.aspect = guiOptions.Resolution.split('x')[2] / guiOptions.Resolution.split('x')[1];
2067
+ new_camera.updateProjectionMatrix();
2068
+
2069
+ const cameraSplat = createCameraSplat(new_camera);
2070
+ cameraSplats.push(cameraSplat);
2071
+ cameraParams.push({
2072
+ position: new_camera.position.clone(),
2073
+ quaternion: new_camera.quaternion.clone(),
2074
+ fov: new_camera.fov,
2075
+ aspect: new_camera.aspect,
2076
+ });
2077
+ scene.add(cameraSplat);
2078
+
2079
+ updateStatus(`Camera ${cameraParams.length} recorded. Press Space for more or Generate!`, cameraParams.length);
2080
+
2081
+ console.log(new_camera.getFocalLength());
2082
+ }
2083
+ });
2084
+
2085
+ // =========================
2086
+ // Scene Initialization
2087
+ // =========================
2088
+
2089
+ // Initialize status
2090
+ updateStatus('FlashWorld initialized. Configure settings to begin.', 0);
2091
+
2092
+ // Add cube splat to the scene
2093
+ let instructionSplat = createCubeSplat(0.25, [1, 1, 1]);
2094
+ instructionSplat.position.set(0, 0, -1);
2095
+ scene.add(instructionSplat);
2096
+ console.log('Cube splat added to scene');
2097
+
2098
+ // Handle window resize
2099
+ window.addEventListener('resize', () => {
2100
+ console.log('Window resized, updating canvas...');
2101
+ // Update canvas size based on current resolution
2102
+ updateCanvasSize();
2103
+ });
2104
+
2105
+ // =========================
2106
+ // Animation Loop
2107
+ // =========================
2108
+ let lastTime = null;
2109
+
2110
+ renderer.setAnimationLoop(function animate(time) {
2111
+ const deltaTime = time - (lastTime || time);
2112
+ lastTime = time;
2113
+
2114
+ // Rotate the cube splat
2115
+ if (instructionSplat) {
2116
+ // instructionSplat.rotation.x += deltaTime / 4000; // rotate around the X axis
+ instructionSplat.rotation.y += deltaTime / 5000; // rotate around the Y axis
+ instructionSplat.rotation.z += deltaTime / 6000; // rotate around the Z axis
2119
+ }
2120
+
2121
+ // No active playback loop; scrubber directly sets camera
2122
+
2123
+ controls.update(camera);
2124
+ renderer.render(scene, camera);
2125
+
2126
+ });
2127
+
2128
+ </script>
2129
+ </body>
2130
+ </html>
models/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .autoencoder_kl_wan import AutoencoderKLWan
2
+ from .transformer_wan import WanTransformer3DModel
3
+ from .reconstruction_model import WANDecoderPixelAligned3DGSReconstructionModel
4
+
5
+ __all__ = ["AutoencoderKLWan", "WanTransformer3DModel", "WANDecoderPixelAligned3DGSReconstructionModel"]
models/autoencoder_kl_wan.py ADDED
@@ -0,0 +1,1467 @@
1
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import FromOriginalModelMixin
24
+ from diffusers.utils import logging
25
+ from diffusers.utils.accelerate_utils import apply_forward_hook
26
+ from diffusers.models.activations import get_activation
27
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
30
+
31
+ import einops
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+ CACHE_T = 2
37
+
38
+ class AvgDown3D(nn.Module):
39
+
40
+ def __init__(
41
+ self,
42
+ in_channels,
43
+ out_channels,
44
+ factor_t,
45
+ factor_s=1,
46
+ ):
47
+ super().__init__()
48
+ self.in_channels = in_channels
49
+ self.out_channels = out_channels
50
+ self.factor_t = factor_t
51
+ self.factor_s = factor_s
52
+ self.factor = self.factor_t * self.factor_s * self.factor_s
53
+
54
+ assert in_channels * self.factor % out_channels == 0
55
+ self.group_size = in_channels * self.factor // out_channels
56
+
57
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
58
+ if not ((x.shape[2] == 1 and self.group_size >= self.factor) or self.factor_t == 1):
59
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
60
+ pad = (0, 0, 0, 0, pad_t, 0)
61
+ x = F.pad(x, pad)
62
+ B, C, T, H, W = x.shape
63
+ x = x.view(
64
+ B,
65
+ C,
66
+ T // self.factor_t,
67
+ self.factor_t,
68
+ H // self.factor_s,
69
+ self.factor_s,
70
+ W // self.factor_s,
71
+ self.factor_s,
72
+ )
73
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
74
+ x = x.view(
75
+ B,
76
+ C * self.factor,
77
+ T // self.factor_t,
78
+ H // self.factor_s,
79
+ W // self.factor_s,
80
+ )
81
+ x = x.view(
82
+ B,
83
+ self.out_channels,
84
+ self.group_size,
85
+ T // self.factor_t,
86
+ H // self.factor_s,
87
+ W // self.factor_s,
88
+ )
89
+ x = x.mean(dim=2)
90
+ return x
91
+ else:
92
+ # print(1)
93
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
94
+ pad = (0, 0, 0, 0, pad_t, 0)
95
+ B, C, T, H, W = x.shape
96
+ x = x.view(
97
+ B,
98
+ C,
99
+ T,
100
+ 1,
101
+ H // self.factor_s,
102
+ self.factor_s,
103
+ W // self.factor_s,
104
+ self.factor_s,
105
+ )
106
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
107
+ x = x.view(
108
+ B,
109
+ C * self.factor // self.factor_t,
110
+ T,
111
+ H // self.factor_s,
112
+ W // self.factor_s,
113
+ )
114
+ x = x.view(
115
+ B,
116
+ self.out_channels,
117
+ self.group_size // self.factor_t,
118
+ T,
119
+ H // self.factor_s,
120
+ W // self.factor_s,
121
+ )
122
+ # Since the pad values are zeros, the mean only comes out right after the division below
123
+ x = x.mean(dim=2) / (pad_t + 1)
124
+ return x
125
+
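`AvgDown3D` folds `factor_t x factor_s x factor_s` blocks into the channel dimension and then takes a grouped mean, so it downsamples without any learned parameters. A shape sanity check, assuming the class is importable from the module above:

```python
import torch
from models.autoencoder_kl_wan import AvgDown3D

# Fold 2x2x2 blocks into channels, then average groups back down to 96 channels.
down = AvgDown3D(in_channels=96, out_channels=96, factor_t=2, factor_s=2)
x = torch.randn(1, 96, 4, 32, 32)  # (B, C, T, H, W)
print(down(x).shape)               # torch.Size([1, 96, 2, 16, 16])
```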
126
+ class DupUp3D(nn.Module):
127
+
128
+ def __init__(
129
+ self,
130
+ in_channels: int,
131
+ out_channels: int,
132
+ factor_t,
133
+ factor_s=1,
134
+ ):
135
+ super().__init__()
136
+ self.in_channels = in_channels
137
+ self.out_channels = out_channels
138
+
139
+ self.factor_t = factor_t
140
+ self.factor_s = factor_s
141
+ self.factor = self.factor_t * self.factor_s * self.factor_s
142
+
143
+ assert out_channels * self.factor % in_channels == 0
144
+ self.repeats = out_channels * self.factor // in_channels
145
+
146
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
147
+ if not (first_chunk and x.shape[2] == 1):
148
+ x = x.repeat_interleave(self.repeats, dim=1)
149
+ x = x.view(
150
+ x.size(0),
151
+ self.out_channels,
152
+ self.factor_t,
153
+ self.factor_s,
154
+ self.factor_s,
155
+ x.size(2),
156
+ x.size(3),
157
+ x.size(4),
158
+ )
159
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
160
+ x = x.view(
161
+ x.size(0),
162
+ self.out_channels,
163
+ x.size(2) * self.factor_t,
164
+ x.size(4) * self.factor_s,
165
+ x.size(6) * self.factor_s,
166
+ )
167
+ if first_chunk:
168
+ x = x[:, :, self.factor_t - 1:, :, :]
169
+ return x
170
+ else:
171
+ # print(1)
172
+ x = x.repeat_interleave(self.repeats // self.factor_t, dim=1)
173
+ x = x.view(
174
+ x.size(0),
175
+ self.out_channels,
176
+ 1,
177
+ self.factor_s,
178
+ self.factor_s,
179
+ x.size(2),
180
+ x.size(3),
181
+ x.size(4),
182
+ )
183
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
184
+ x = x.view(
185
+ x.size(0),
186
+ self.out_channels,
187
+ x.size(2),
188
+ x.size(4) * self.factor_s,
189
+ x.size(6) * self.factor_s,
190
+ )
191
+ return x
192
+
193
+ class WanCausalConv3d(nn.Conv3d):
194
+ r"""
195
+ A custom 3D causal convolution layer with feature caching support.
196
+
197
+ This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
198
+ caching for efficient inference.
199
+
200
+ Args:
201
+ in_channels (int): Number of channels in the input image
202
+ out_channels (int): Number of channels produced by the convolution
203
+ kernel_size (int or tuple): Size of the convolving kernel
204
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
205
+ padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
206
+ """
207
+
208
+ def __init__(
209
+ self,
210
+ in_channels: int,
211
+ out_channels: int,
212
+ kernel_size: Union[int, Tuple[int, int, int]],
213
+ stride: Union[int, Tuple[int, int, int]] = 1,
214
+ padding: Union[int, Tuple[int, int, int]] = 0,
215
+ ) -> None:
216
+ super().__init__(
217
+ in_channels=in_channels,
218
+ out_channels=out_channels,
219
+ kernel_size=kernel_size,
220
+ stride=stride,
221
+ padding=padding,
222
+ )
223
+
224
+ # Set up causal padding
225
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
226
+ self.padding = (0, 0, 0)
227
+
228
+ def forward(self, x, cache_x=None):
229
+ padding = list(self._padding)
230
+ if cache_x is not None and self._padding[4] > 0:
231
+ cache_x = cache_x.to(x.device)
232
+ x = torch.cat([cache_x, x], dim=2)
233
+ padding[4] -= cache_x.shape[2]
234
+
235
+ if any(padding):
236
+ x = F.pad(x, padding)
237
+
238
+ # print(x.shape)
239
+ return super().forward(x)
240
+
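`WanCausalConv3d` places all temporal zero-padding before the first frame, so output frame t never depends on inputs after t; `cache_x` lets chunked inference replace that padding with the real trailing frames of the previous chunk. A small equivalence sketch, assuming the class is importable from this module:

```python
import torch
from models.autoencoder_kl_wan import WanCausalConv3d

conv = WanCausalConv3d(8, 8, kernel_size=3, padding=1).eval()
x = torch.randn(1, 8, 5, 16, 16)  # 5 frames

y_full = conv(x)  # all temporal padding goes before frame 0 (causal)

# Chunked pass: hand the last CACHE_T = 2 frames of chunk 1 to chunk 2 as cache.
y1 = conv(x[:, :, :3])
y2 = conv(x[:, :, 3:], cache_x=x[:, :, 1:3])
print(torch.allclose(torch.cat([y1, y2], dim=2), y_full, atol=1e-6))  # True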
241
+
242
+ class WanRMS_norm(nn.Module):
243
+ r"""
244
+ A custom RMS normalization layer.
245
+
246
+ Args:
247
+ dim (int): The number of dimensions to normalize over.
248
+ channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
249
+ Default is True.
250
+ images (bool, optional): Whether the input represents image data. Default is True.
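+ weight (bool, optional): Whether to include a learnable scale term. Default is True.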
251
+ bias (bool, optional): Whether to include a learnable bias term. Default is False.
252
+ """
253
+
254
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, weight: bool = True, bias: bool = False) -> None:
255
+ super().__init__()
256
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
257
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
258
+
259
+ self.channel_first = channel_first
260
+ self.scale = dim**0.5
261
+ self.gamma = nn.Parameter(torch.ones(shape)) if weight else 1.0
262
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
263
+
264
+ def forward(self, x):
265
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
266
+
267
+
268
+ class WanUpsample(nn.Upsample):
269
+ r"""
270
+ Perform upsampling while ensuring the output tensor has the same data type as the input.
271
+
272
+ Args:
273
+ x (torch.Tensor): Input tensor to be upsampled.
274
+
275
+ Returns:
276
+ torch.Tensor: Upsampled tensor with the same data type as the input.
277
+ """
278
+
279
+ def forward(self, x):
280
+ return super().forward(x.float()).type_as(x)
281
+
282
+
283
+ class WanResample(nn.Module):
284
+ r"""
285
+ A custom resampling module for 2D and 3D data.
286
+
287
+ Args:
288
+ dim (int): The number of input/output channels.
289
+ mode (str): The resampling mode. Must be one of:
290
+ - 'none': No resampling (identity operation).
291
+ - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
292
+ - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
293
+ - 'downsample2d': 2D downsampling with zero-padding and convolution.
294
+ - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
295
+ """
296
+
297
+ def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
298
+ super().__init__()
299
+ self.dim = dim
300
+ self.mode = mode
301
+
302
+ # default to dim // 2
303
+ if upsample_out_dim is None:
304
+ upsample_out_dim = dim // 2
305
+
306
+ # layers
307
+ if mode == "upsample2d":
308
+ self.resample = nn.Sequential(
309
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
310
+ )
311
+ elif mode == "upsample3d":
312
+ self.resample = nn.Sequential(
313
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
314
+ )
315
+ self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
316
+
317
+ elif mode == "downsample2d":
318
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
319
+ elif mode == "downsample3d":
320
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
321
+ self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
322
+
323
+ else:
324
+ self.resample = nn.Identity()
325
+
326
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
327
+ b, c, t, h, w = x.size()
328
+ if self.mode == "upsample3d":
329
+ if feat_cache is not None:
330
+ idx = feat_idx[0]
331
+ if feat_cache[idx] is None:
332
+ feat_cache[idx] = "Rep"
333
+ feat_idx[0] += 1
334
+ else:
335
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
336
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
337
+ # cache last frame of last two chunk
338
+ cache_x = torch.cat(
339
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
340
+ )
341
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
342
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
343
+ if feat_cache[idx] == "Rep":
344
+ x = self.time_conv(x)
345
+ else:
346
+ x = self.time_conv(x, feat_cache[idx])
347
+ feat_cache[idx] = cache_x
348
+ feat_idx[0] += 1
349
+
350
+ x = x.reshape(b, 2, c, t, h, w)
351
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
352
+ x = x.reshape(b, c, t * 2, h, w)
353
+ t = x.shape[2]
354
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
355
+ x = self.resample(x)
356
+ x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
357
+
358
+ if self.mode == "downsample3d":
359
+ if feat_cache is not None:
360
+ idx = feat_idx[0]
361
+ if feat_cache[idx] is None:
362
+ feat_cache[idx] = x.clone()
363
+ feat_idx[0] += 1
364
+ else:
365
+ cache_x = x[:, :, -1:, :, :].clone()
366
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
367
+ feat_cache[idx] = cache_x
368
+ feat_idx[0] += 1
369
+ return x
370
+
371
+
372
+ class WanResidualBlock(nn.Module):
373
+ r"""
374
+ A custom residual block module.
375
+
376
+ Args:
377
+ in_dim (int): Number of input channels.
378
+ out_dim (int): Number of output channels.
379
+ dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
380
+ non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
381
+ """
382
+
383
+ def __init__(
384
+ self,
385
+ in_dim: int,
386
+ out_dim: int,
387
+ dropout: float = 0.0,
388
+ non_linearity: str = "silu",
389
+ ) -> None:
390
+ super().__init__()
391
+ self.in_dim = in_dim
392
+ self.out_dim = out_dim
393
+ self.nonlinearity = get_activation(non_linearity)
394
+
395
+ # layers
396
+ self.norm1 = WanRMS_norm(in_dim, images=False)
397
+ self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
398
+ self.norm2 = WanRMS_norm(out_dim, images=False)
399
+ self.dropout = nn.Dropout(dropout)
400
+ self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
401
+ self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
402
+
403
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
404
+ # Apply shortcut connection
405
+ h = self.conv_shortcut(x)
406
+
407
+ # First normalization and activation
408
+ x = self.norm1(x)
409
+ x = self.nonlinearity(x)
410
+
411
+ if feat_cache is not None:
412
+ idx = feat_idx[0]
413
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
414
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
415
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
416
+
417
+ x = self.conv1(x, feat_cache[idx])
418
+ feat_cache[idx] = cache_x
419
+ feat_idx[0] += 1
420
+ else:
421
+ x = self.conv1(x)
422
+
423
+ # Second normalization and activation
424
+ x = self.norm2(x)
425
+ x = self.nonlinearity(x)
426
+
427
+ # Dropout
428
+ x = self.dropout(x)
429
+
430
+ if feat_cache is not None:
431
+ idx = feat_idx[0]
432
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
433
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
434
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
435
+
436
+ x = self.conv2(x, feat_cache[idx])
437
+ feat_cache[idx] = cache_x
438
+ feat_idx[0] += 1
439
+ else:
440
+ x = self.conv2(x)
441
+
442
+ # Add residual connection
443
+ return h.add_(x)
444
+
445
+
446
+ class WanAttentionBlock(nn.Module):
447
+ r"""
448
+ Causal self-attention with a single head.
449
+
450
+ Args:
451
+ dim (int): The number of channels in the input tensor.
452
+ """
453
+
454
+ def __init__(self, dim):
455
+ super().__init__()
456
+ self.dim = dim
457
+
458
+ # layers
459
+ self.norm = WanRMS_norm(dim)
460
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
461
+ self.proj = nn.Conv2d(dim, dim, 1)
462
+
463
+ def forward(self, x):
464
+ identity = x
465
+ batch_size, channels, time, height, width = x.size()
466
+
467
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
468
+ x = self.norm(x)
469
+
470
+ # compute query, key, value
471
+ qkv = self.to_qkv(x)
472
+ qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
473
+ qkv = qkv.permute(0, 1, 3, 2).contiguous()
474
+ q, k, v = qkv.chunk(3, dim=-1)
475
+
476
+ # apply attention
477
+ x = F.scaled_dot_product_attention(q, k, v)
478
+
479
+ x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
480
+
481
+ # output projection
482
+ x = self.proj(x)
483
+
484
+ # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
485
+ x = x.view(batch_size, time, channels, height, width)
486
+ x = x.permute(0, 2, 1, 3, 4)
487
+
488
+ return identity.add_(x)
489
+
490
+
491
+ class WanMidBlock(nn.Module):
492
+ """
493
+ Middle block for WanVAE encoder and decoder.
494
+
495
+ Args:
496
+ dim (int): Number of input/output channels.
497
+ dropout (float): Dropout rate.
498
+ non_linearity (str): Type of non-linearity to use.
499
+ """
500
+
501
+ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
502
+ super().__init__()
503
+ self.dim = dim
504
+
505
+ # Create the components
506
+ resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
507
+ attentions = []
508
+ for _ in range(num_layers):
509
+ attentions.append(WanAttentionBlock(dim))
510
+ resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
511
+ self.attentions = nn.ModuleList(attentions)
512
+ self.resnets = nn.ModuleList(resnets)
513
+
514
+ self.gradient_checkpointing = False
515
+
516
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
517
+ # First residual block
518
+ x = self.resnets[0](x, feat_cache, feat_idx)
519
+
520
+ # Process through attention and residual blocks
521
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
522
+ if attn is not None:
523
+ x = attn(x)
524
+
525
+ x = resnet(x, feat_cache, feat_idx)
526
+
527
+ return x
528
+
529
+
530
+ class WanResidualDownBlock(nn.Module):
531
+
532
+ def __init__(self,
533
+ in_dim,
534
+ out_dim,
535
+ dropout,
536
+ num_res_blocks,
537
+ temperal_downsample=False,
538
+ down_flag=False):
539
+ super().__init__()
540
+
541
+ # Shortcut path with downsample
542
+ self.avg_shortcut = AvgDown3D(
543
+ in_dim,
544
+ out_dim,
545
+ factor_t=2 if temperal_downsample else 1,
546
+ factor_s=2 if down_flag else 1,
547
+ )
548
+
549
+ # Main path with residual blocks and downsample
550
+ resnets = []
551
+ for _ in range(num_res_blocks):
552
+ resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
553
+ in_dim = out_dim
554
+ self.resnets = nn.ModuleList(resnets)
555
+
556
+ # Add the final downsample block
557
+ if down_flag:
558
+ mode = "downsample3d" if temperal_downsample else "downsample2d"
559
+ self.downsampler = WanResample(out_dim, mode=mode)
560
+ else:
561
+ self.downsampler = None
562
+
563
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
564
+ x_copy = x.clone()
565
+ for resnet in self.resnets:
566
+ x = resnet(x, feat_cache, feat_idx)
567
+ if self.downsampler is not None:
568
+ x = self.downsampler(x, feat_cache, feat_idx)
569
+
570
+ return self.avg_shortcut(x_copy).add_(x)
571
+
572
+ class WanEncoder3d(nn.Module):
573
+ r"""
574
+ A 3D encoder module.
575
+
576
+ Args:
577
+ dim (int): The base number of channels in the first layer.
578
+ z_dim (int): The dimensionality of the latent space.
579
+ dim_mult (list of int): Multipliers for the number of channels in each block.
580
+ num_res_blocks (int): Number of residual blocks in each block.
581
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
582
+ temperal_downsample (list of bool): Whether to downsample temporally in each block.
583
+ dropout (float): Dropout rate for the dropout layers.
584
+ non_linearity (str): Type of non-linearity to use.
585
+ """
586
+
587
+ def __init__(
588
+ self,
589
+ in_channels: int = 3,
590
+ dim=128,
591
+ z_dim=4,
592
+ dim_mult=[1, 2, 4, 4],
593
+ num_res_blocks=2,
594
+ attn_scales=[],
595
+ temperal_downsample=[True, True, False],
596
+ dropout=0.0,
597
+ non_linearity: str = "silu",
598
+ is_residual: bool = False, # wan 2.2 vae use a residual downblock
599
+ ):
600
+ super().__init__()
601
+ self.dim = dim
602
+ self.z_dim = z_dim
603
+ self.dim_mult = dim_mult
604
+ self.num_res_blocks = num_res_blocks
605
+ self.attn_scales = attn_scales
606
+ self.temperal_downsample = temperal_downsample
607
+ self.nonlinearity = get_activation(non_linearity)
608
+
609
+ # dimensions
610
+ dims = [dim * u for u in [1] + dim_mult]
611
+ scale = 1.0
612
+
613
+ # init block
614
+ self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
615
+
616
+ # downsample blocks
617
+ self.down_blocks = nn.ModuleList([])
618
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
619
+ # residual (+attention) blocks
620
+ if is_residual:
621
+ self.down_blocks.append(
622
+ WanResidualDownBlock(
623
+ in_dim,
624
+ out_dim,
625
+ dropout,
626
+ num_res_blocks,
627
+ temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
628
+ down_flag=i != len(dim_mult) - 1,
629
+ )
630
+ )
631
+ else:
632
+ for _ in range(num_res_blocks):
633
+ self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
634
+ if scale in attn_scales:
635
+ self.down_blocks.append(WanAttentionBlock(out_dim))
636
+ in_dim = out_dim
637
+
638
+ # downsample block
639
+ if i != len(dim_mult) - 1:
640
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
641
+ self.down_blocks.append(WanResample(out_dim, mode=mode))
642
+ scale /= 2.0
643
+
644
+ # middle blocks
645
+ self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
646
+
647
+ # output blocks
648
+ self.norm_out = WanRMS_norm(out_dim, images=False)
649
+ self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
650
+
651
+ self.gradient_checkpointing = False
652
+
653
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
654
+ assert x.shape[2] == 1
655
+ if feat_cache is not None:
656
+ idx = feat_idx[0]
657
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
658
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
659
+ # cache last frame of last two chunk
660
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
661
+ x = self.conv_in(x, feat_cache[idx])
662
+ feat_cache[idx] = cache_x
663
+ feat_idx[0] += 1
664
+ else:
665
+ x = self.conv_in(x)
666
+
667
+ ## downsamples
668
+ for layer in self.down_blocks:
669
+ if feat_cache is not None:
670
+ x = layer(x, feat_cache, feat_idx)
671
+ else:
672
+ x = layer(x)
673
+
674
+ ## middle
675
+ x = self.mid_block(x, feat_cache, feat_idx)
676
+
677
+ ## head
678
+ x = self.norm_out(x)
679
+ x = self.nonlinearity(x)
680
+ if feat_cache is not None:
681
+ idx = feat_idx[0]
682
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
683
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
684
+ # cache last frame of last two chunk
685
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
686
+ x = self.conv_out(x, feat_cache[idx])
687
+ feat_cache[idx] = cache_x
688
+ feat_idx[0] += 1
689
+ else:
690
+ x = self.conv_out(x)
691
+ return x
692
+
693
+ class WanResidualUpBlock(nn.Module):
694
+ """
695
+ A block that handles upsampling for the WanVAE decoder.
696
+
697
+ Args:
698
+ in_dim (int): Input dimension
699
+ out_dim (int): Output dimension
700
+ num_res_blocks (int): Number of residual blocks
701
+ dropout (float): Dropout rate
702
+ temperal_upsample (bool): Whether to upsample on temporal dimension
703
+ up_flag (bool): Whether to upsample or not
704
+ non_linearity (str): Type of non-linearity to use
705
+ """
706
+
707
+ def __init__(
708
+ self,
709
+ in_dim: int,
710
+ out_dim: int,
711
+ num_res_blocks: int,
712
+ dropout: float = 0.0,
713
+ temperal_upsample: bool = False,
714
+ up_flag: bool = False,
715
+ non_linearity: str = "silu",
716
+ ):
717
+ super().__init__()
718
+ self.in_dim = in_dim
719
+ self.out_dim = out_dim
720
+
721
+ if up_flag:
722
+ self.avg_shortcut = DupUp3D(
723
+ in_dim,
724
+ out_dim,
725
+ factor_t=2 if temperal_upsample else 1,
726
+ factor_s=2,
727
+ )
728
+ else:
729
+ self.avg_shortcut = None
730
+
731
+ # create residual blocks
732
+ resnets = []
733
+ current_dim = in_dim
734
+ for _ in range(num_res_blocks + 1):
735
+ resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
736
+ current_dim = out_dim
737
+
738
+ self.resnets = nn.ModuleList(resnets)
739
+
740
+ # Add upsampling layer if needed
741
+ if up_flag:
742
+ upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
743
+ self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
744
+ else:
745
+ self.upsampler = None
746
+
747
+ self.gradient_checkpointing = False
748
+
749
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
750
+ """
751
+ Forward pass through the upsampling block.
752
+
753
+ Args:
754
+ x (torch.Tensor): Input tensor
755
+ feat_cache (list, optional): Feature cache for causal convolutions
756
+ feat_idx (list, optional): Feature index for cache management
757
+
758
+ Returns:
759
+ torch.Tensor: Output tensor
760
+ """
761
+ x_copy = x.clone()
762
+
763
+ for resnet in self.resnets:
764
+ if feat_cache is not None:
765
+ x = resnet(x, feat_cache, feat_idx)
766
+ else:
767
+ x = resnet(x)
768
+
769
+ if self.upsampler is not None:
770
+ if feat_cache is not None:
771
+ x = self.upsampler(x, feat_cache, feat_idx)
772
+ else:
773
+ x = self.upsampler(x)
774
+
775
+ if self.avg_shortcut is not None:
776
+ # print(x.shape, x_copy.shape, self.avg_shortcut(x_copy, first_chunk=first_chunk).shape)
777
+ x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
778
+
779
+ return x
780
+
781
+ class WanUpBlock(nn.Module):
782
+ """
783
+ A block that handles upsampling for the WanVAE decoder.
784
+
785
+ Args:
786
+ in_dim (int): Input dimension
787
+ out_dim (int): Output dimension
788
+ num_res_blocks (int): Number of residual blocks
789
+ dropout (float): Dropout rate
790
+ upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
791
+ non_linearity (str): Type of non-linearity to use
792
+ """
793
+
794
+ def __init__(
795
+ self,
796
+ in_dim: int,
797
+ out_dim: int,
798
+ num_res_blocks: int,
799
+ dropout: float = 0.0,
800
+ upsample_mode: Optional[str] = None,
801
+ non_linearity: str = "silu",
802
+ ):
803
+ super().__init__()
804
+ self.in_dim = in_dim
805
+ self.out_dim = out_dim
806
+
807
+ # Create layers list
808
+ resnets = []
809
+ # Add residual blocks and attention if needed
810
+ current_dim = in_dim
811
+ for _ in range(num_res_blocks + 1):
812
+ resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
813
+ current_dim = out_dim
814
+
815
+ self.resnets = nn.ModuleList(resnets)
816
+
817
+ # Add upsampling layer if needed
818
+ self.upsamplers = None
819
+ if upsample_mode is not None:
820
+ self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
821
+
822
+ self.gradient_checkpointing = False
823
+
824
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
825
+ """
826
+ Forward pass through the upsampling block.
827
+
828
+ Args:
829
+ x (torch.Tensor): Input tensor
830
+ feat_cache (list, optional): Feature cache for causal convolutions
831
+ feat_idx (list, optional): Feature index for cache management
832
+
833
+ Returns:
834
+ torch.Tensor: Output tensor
835
+ """
836
+ for resnet in self.resnets:
837
+ if feat_cache is not None:
838
+ x = resnet(x, feat_cache, feat_idx)
839
+ else:
840
+ x = resnet(x)
841
+
842
+ if self.upsamplers is not None:
843
+ if feat_cache is not None:
844
+ x = self.upsamplers[0](x, feat_cache, feat_idx)
845
+ else:
846
+ x = self.upsamplers[0](x)
847
+ return x
848
+
849
+
850
+ class WanDecoder3d(nn.Module):
851
+ r"""
852
+ A 3D decoder module.
853
+
854
+ Args:
855
+ dim (int): The base number of channels in the first layer.
856
+ z_dim (int): The dimensionality of the latent space.
857
+ dim_mult (list of int): Multipliers for the number of channels in each block.
858
+ num_res_blocks (int): Number of residual blocks in each block.
859
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
860
+ temperal_upsample (list of bool): Whether to upsample temporally in each block.
861
+ dropout (float): Dropout rate for the dropout layers.
862
+ non_linearity (str): Type of non-linearity to use.
863
+ """
864
+
865
+ def __init__(
866
+ self,
867
+ dim=128,
868
+ z_dim=4,
869
+ dim_mult=[1, 2, 4, 4],
870
+ num_res_blocks=2,
871
+ attn_scales=[],
872
+ temperal_upsample=[False, True, True],
873
+ dropout=0.0,
874
+ non_linearity: str = "silu",
875
+ out_channels: int = 3,
876
+ is_residual: bool = False,
877
+ ):
878
+ super().__init__()
879
+ self.dim = dim
880
+ self.z_dim = z_dim
881
+ self.dim_mult = dim_mult
882
+ self.num_res_blocks = num_res_blocks
883
+ self.attn_scales = attn_scales
884
+ self.temperal_upsample = temperal_upsample
885
+
886
+ self.nonlinearity = get_activation(non_linearity)
887
+
888
+ # dimensions
889
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
890
+
891
+ # init block
892
+ self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
893
+
894
+ # middle blocks
895
+ self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
896
+
897
+ # upsample blocks
898
+ self.up_blocks = nn.ModuleList([])
899
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
900
+ # residual (+attention) blocks
901
+ if i > 0 and not is_residual:
902
+ # wan vae 2.1
903
+ in_dim = in_dim // 2
904
+
905
+ # determine if we need upsampling
906
+ up_flag = i != len(dim_mult) - 1
907
+ # determine upsampling mode, if not upsampling, set to None
908
+ upsample_mode = None
909
+ if up_flag and temperal_upsample[i]:
910
+ upsample_mode = "upsample3d"
911
+ elif up_flag:
912
+ upsample_mode = "upsample2d"
913
+ # Create and add the upsampling block
914
+ if is_residual:
915
+ up_block = WanResidualUpBlock(
916
+ in_dim=in_dim,
917
+ out_dim=out_dim,
918
+ num_res_blocks=num_res_blocks,
919
+ dropout=dropout,
920
+ temperal_upsample=temperal_upsample[i] if up_flag else False,
921
+ up_flag=up_flag,
922
+ non_linearity=non_linearity,
923
+ )
924
+ else:
925
+ up_block = WanUpBlock(
926
+ in_dim=in_dim,
927
+ out_dim=out_dim,
928
+ num_res_blocks=num_res_blocks,
929
+ dropout=dropout,
930
+ upsample_mode=upsample_mode,
931
+ non_linearity=non_linearity,
932
+ )
933
+ self.up_blocks.append(up_block)
934
+
935
+ # output blocks
936
+ self.norm_out = WanRMS_norm(out_dim, images=False)
937
+ self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
938
+
939
+ self.gradient_checkpointing = False
940
+
941
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
942
+ assert x.shape[2] == 1
943
+ ## conv1
944
+ if feat_cache is not None:
945
+ idx = feat_idx[0]
946
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
947
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
948
+ # cache last frame of last two chunk
949
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
950
+ x = self.conv_in(x, feat_cache[idx])
951
+ feat_cache[idx] = cache_x
952
+ feat_idx[0] += 1
953
+ else:
954
+ x = self.conv_in(x)
955
+
956
+ ## middle
957
+ x = self.mid_block(x, feat_cache, feat_idx)
958
+
959
+ ## upsamples
960
+ for up_block in self.up_blocks:
961
+ x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
962
+
963
+ ## head
964
+ x = self.norm_out(x)
965
+ x = self.nonlinearity(x)
966
+ if feat_cache is not None:
967
+ idx = feat_idx[0]
968
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
969
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
970
+ # cache last frame of last two chunk
971
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
972
+ x = self.conv_out(x, feat_cache[idx])
973
+ feat_cache[idx] = cache_x
974
+ feat_idx[0] += 1
975
+ else:
976
+ x = self.conv_out(x)
977
+ return x
978
+
979
+
980
+ def patchify(x, patch_size):
981
+ # YiYi TODO: refactor this
982
+ from einops import rearrange
983
+ if patch_size == 1:
984
+ return x
985
+ if x.dim() == 4:
986
+ x = rearrange(
987
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
988
+ elif x.dim() == 5:
989
+ x = rearrange(
990
+ x,
991
+ "b c f (h q) (w r) -> b (c r q) f h w",
992
+ q=patch_size,
993
+ r=patch_size,
994
+ )
995
+ else:
996
+ raise ValueError(f"Invalid input shape: {x.shape}")
997
+
998
+ return x
999
+
1000
+
1001
+ def unpatchify(x, patch_size):
1002
+ # YiYi TODO: refactor this
1003
+ from einops import rearrange
1004
+ if patch_size == 1:
1005
+ return x
1006
+
1007
+ if x.dim() == 4:
1008
+ x = rearrange(
1009
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
1010
+ elif x.dim() == 5:
1011
+ x = rearrange(
1012
+ x,
1013
+ "b (c r q) f h w -> b c f (h q) (w r)",
1014
+ q=patch_size,
1015
+ r=patch_size,
1016
+ )
1017
+ return x
1018
+
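`patchify` folds non-overlapping `patch_size x patch_size` spatial patches into channels and `unpatchify` exactly inverts it (this is how the Wan 2.2 VAE trades resolution for channel depth). A roundtrip check, assuming the functions above are importable:

```python
import torch
from models.autoencoder_kl_wan import patchify, unpatchify

x = torch.randn(1, 3, 4, 32, 32)                    # (B, C, F, H, W)
p = patchify(x, patch_size=2)
print(p.shape)                                      # torch.Size([1, 12, 4, 16, 16])
print(torch.equal(unpatchify(p, patch_size=2), x))  # True: exact inverse
```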
1019
+ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1020
+ r"""
1021
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
1022
+ Introduced in [Wan 2.1].
1023
+
1024
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
1025
+ for all models (such as downloading or saving).
1026
+ """
1027
+
1028
+ _supports_gradient_checkpointing = False
1029
+
1030
+ @register_to_config
1031
+ def __init__(
1032
+ self,
1033
+ base_dim: int = 96,
1034
+ decoder_base_dim: Optional[int] = None,
1035
+ z_dim: int = 16,
1036
+ dim_mult: Tuple[int] = [1, 2, 4, 4],
1037
+ num_res_blocks: int = 2,
1038
+ attn_scales: List[float] = [],
1039
+ temperal_downsample: List[bool] = [False, True, True],
1040
+ dropout: float = 0.0,
1041
+ latents_mean: List[float] = [
1042
+ -0.7571,
1043
+ -0.7089,
1044
+ -0.9113,
1045
+ 0.1075,
1046
+ -0.1745,
1047
+ 0.9653,
1048
+ -0.1517,
1049
+ 1.5508,
1050
+ 0.4134,
1051
+ -0.0715,
1052
+ 0.5517,
1053
+ -0.3632,
1054
+ -0.1922,
1055
+ -0.9497,
1056
+ 0.2503,
1057
+ -0.2921,
1058
+ ],
1059
+ latents_std: List[float] = [
1060
+ 2.8184,
1061
+ 1.4541,
1062
+ 2.3275,
1063
+ 2.6558,
1064
+ 1.2196,
1065
+ 1.7708,
1066
+ 2.6052,
1067
+ 2.0743,
1068
+ 3.2687,
1069
+ 2.1526,
1070
+ 2.8652,
1071
+ 1.5579,
1072
+ 1.6382,
1073
+ 1.1253,
1074
+ 2.8251,
1075
+ 1.9160,
1076
+ ],
1077
+ is_residual: bool = False,
1078
+ in_channels: int = 3,
1079
+ out_channels: int = 3,
1080
+ patch_size: Optional[int] = None,
1081
+ scale_factor_temporal: Optional[int] = 4,
1082
+ scale_factor_spatial: Optional[int] = 8,
1083
+ clip_output: bool = True,
1084
+ ) -> None:
1085
+ super().__init__()
1086
+
1087
+ self.z_dim = z_dim
1088
+ self.temperal_downsample = temperal_downsample
1089
+ self.temperal_upsample = temperal_downsample[::-1]
1090
+
1091
+ if decoder_base_dim is None:
1092
+ decoder_base_dim = base_dim
1093
+
1094
+ self.encoder = WanEncoder3d(
1095
+ in_channels=in_channels, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, is_residual=is_residual
1096
+ )
1097
+ self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
1098
+ self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
1099
+
1100
+ self.decoder = WanDecoder3d(
1101
+ dim=decoder_base_dim, z_dim=z_dim, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_upsample=self.temperal_upsample, dropout=dropout, out_channels=out_channels, is_residual=is_residual
1102
+ )
1103
+
1104
+ self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
1105
+
1106
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
1107
+ # to perform decoding of a single video latent at a time.
1108
+ self.use_slicing = False
1109
+
1110
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
1111
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
1112
+ # intermediate tiles together, the memory requirement can be lowered.
1113
+ self.use_tiling = False
1114
+
1115
+ # The minimal tile height and width for spatial tiling to be used
1116
+ self.tile_sample_min_height = 256
1117
+ self.tile_sample_min_width = 256
1118
+
1119
+ # The minimal distance between two spatial tiles
1120
+ self.tile_sample_stride_height = 192
1121
+ self.tile_sample_stride_width = 192
1122
+
1123
+ # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
1124
+ self._cached_conv_counts = {
1125
+ "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
1126
+ if self.decoder is not None
1127
+ else 0,
1128
+ "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
1129
+ if self.encoder is not None
1130
+ else 0,
1131
+ }
1132
+
1133
+ def enable_tiling(
1134
+ self,
1135
+ tile_sample_min_height: Optional[int] = None,
1136
+ tile_sample_min_width: Optional[int] = None,
1137
+ tile_sample_stride_height: Optional[float] = None,
1138
+ tile_sample_stride_width: Optional[float] = None,
1139
+ ) -> None:
1140
+ r"""
1141
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1142
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
1143
+ processing larger images.
1144
+
1145
+ Args:
1146
+ tile_sample_min_height (`int`, *optional*):
1147
+ The minimum height required for a sample to be separated into tiles across the height dimension.
1148
+ tile_sample_min_width (`int`, *optional*):
1149
+ The minimum width required for a sample to be separated into tiles across the width dimension.
1150
+ tile_sample_stride_height (`int`, *optional*):
1151
+ The stride between two consecutive vertical tiles. This is to ensure that there are
1152
+ no tiling artifacts produced across the height dimension.
1153
+ tile_sample_stride_width (`int`, *optional*):
1154
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
1155
+ artifacts produced across the width dimension.
1156
+ """
1157
+ self.use_tiling = True
1158
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1159
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1160
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
1161
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
1162
+
1163
+ def disable_tiling(self) -> None:
1164
+ r"""
1165
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1166
+ decoding in one step.
1167
+ """
1168
+ self.use_tiling = False
1169
+
1170
+ def enable_slicing(self) -> None:
1171
+ r"""
1172
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1173
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1174
+ """
1175
+ self.use_slicing = True
1176
+
1177
+ def disable_slicing(self) -> None:
1178
+ r"""
1179
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1180
+ decoding in one step.
1181
+ """
1182
+ self.use_slicing = False
1183
+
1184
+ def clear_cache(self):
1185
+ # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
1186
+ self._conv_num = self._cached_conv_counts["decoder"]
1187
+ self._conv_idx = [0]
1188
+ self._feat_map = [None] * self._conv_num
1189
+ # cache encode
1190
+ self._enc_conv_num = self._cached_conv_counts["encoder"]
1191
+ self._enc_conv_idx = [0]
1192
+ self._enc_feat_map = [None] * self._enc_conv_num
1193
+
1194
+ def _encode(self, x: torch.Tensor):
1195
+ _, _, num_frame, height, width = x.shape
1196
+
1197
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
1198
+ return self.tiled_encode(x)
1199
+
1200
+ self.clear_cache()
1201
+ if self.config.patch_size is not None:
1202
+ x = patchify(x, patch_size=self.config.patch_size)
1203
+ iter_ = 1 + (num_frame - 1) // 4
1204
+ self._enc_feat_map = None if iter_ == 1 else self._enc_feat_map
1205
+ for i in range(iter_):
1206
+ self._enc_conv_idx = [0]
1207
+ if i == 0:
1208
+ out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
1209
+ else:
1210
+ out_ = self.encoder(
1211
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
1212
+ feat_cache=self._enc_feat_map,
1213
+ feat_idx=self._enc_conv_idx,
1214
+ )
1215
+ out = torch.cat([out, out_], 2)
1216
+
1217
+ enc = self.quant_conv(out)
1218
+ self.clear_cache()
1219
+ return enc
1220
+
1221
+ @apply_forward_hook
1222
+ def encode(
1223
+ self, x: torch.Tensor, return_dict: bool = True
1224
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1225
+ r"""
1226
+ Encode a batch of images into latents.
1227
+
1228
+ Args:
1229
+ x (`torch.Tensor`): Input batch of images.
1230
+ return_dict (`bool`, *optional*, defaults to `True`):
1231
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1232
+
1233
+ Returns:
1234
+ The latent representations of the encoded videos. If `return_dict` is True, a
1235
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1236
+ """
1237
+ if self.use_slicing and x.shape[0] > 1:
1238
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1239
+ h = torch.cat(encoded_slices)
1240
+ else:
1241
+ h = self._encode(x)
1242
+ posterior = DiagonalGaussianDistribution(h)
1243
+
1244
+ if not return_dict:
1245
+ return (posterior,)
1246
+ return AutoencoderKLOutput(latent_dist=posterior)
1247
+
1248
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
1249
+ _, _, num_frame, height, width = z.shape
1250
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1251
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1252
+
1253
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
1254
+ return self.tiled_decode(z, return_dict=return_dict)
1255
+
1256
+ self.clear_cache()
1257
+ self._feat_map = None if num_frame == 1 else self._feat_map
1258
+ x = self.post_quant_conv(z)
1259
+ for i in range(num_frame):
1260
+ self._conv_idx = [0]
1261
+ if i == 0:
1262
+ out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True)
1263
+ else:
1264
+ out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
1265
+ out = torch.cat([out, out_], 2)
1266
+
1267
+ if self.config.clip_output:
1268
+ out = torch.clamp(out, min=-1.0, max=1.0)
1269
+ if self.config.patch_size is not None:
1270
+ out = unpatchify(out, patch_size=self.config.patch_size)
1271
+ self.clear_cache()
1272
+ if not return_dict:
1273
+ return (out,)
1274
+
1275
+ return DecoderOutput(sample=out)
1276
+
1277
+ @apply_forward_hook
1278
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1279
+ r"""
1280
+ Decode a batch of images.
1281
+
1282
+ Args:
1283
+ z (`torch.Tensor`): Input batch of latent vectors.
1284
+ return_dict (`bool`, *optional*, defaults to `True`):
1285
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1286
+
1287
+ Returns:
1288
+ [`~models.vae.DecoderOutput`] or `tuple`:
1289
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1290
+ returned.
1291
+ """
1292
+ if self.use_slicing and z.shape[0] > 1:
1293
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1294
+ decoded = torch.cat(decoded_slices)
1295
+ else:
1296
+ decoded = self._decode(z).sample
1297
+
1298
+ if not return_dict:
1299
+ return (decoded,)
1300
+ return DecoderOutput(sample=decoded)
1301
+
1302
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1303
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
1304
+ for y in range(blend_extent):
1305
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1306
+ y / blend_extent
1307
+ )
1308
+ return b
1309
+
1310
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1311
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
1312
+ for x in range(blend_extent):
1313
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1314
+ x / blend_extent
1315
+ )
1316
+ return b
1317
+
1318
+ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
1319
+ r"""Encode a batch of images using a tiled encoder.
1320
+
1321
+ Args:
1322
+ x (`torch.Tensor`): Input batch of videos.
1323
+
1324
+ Returns:
1325
+ `torch.Tensor`:
1326
+ The latent representation of the encoded videos.
1327
+ """
1328
+ _, _, num_frames, height, width = x.shape
1329
+ latent_height = height // self.spatial_compression_ratio
1330
+ latent_width = width // self.spatial_compression_ratio
1331
+
1332
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1333
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1334
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1335
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1336
+
1337
+ blend_height = tile_latent_min_height - tile_latent_stride_height
1338
+ blend_width = tile_latent_min_width - tile_latent_stride_width
1339
+
1340
+ # Split x into overlapping tiles and encode them separately.
1341
+ # The tiles have an overlap to avoid seams between tiles.
1342
+ rows = []
1343
+ for i in range(0, height, self.tile_sample_stride_height):
1344
+ row = []
1345
+ for j in range(0, width, self.tile_sample_stride_width):
1346
+ self.clear_cache()
1347
+ time = []
1348
+ frame_range = 1 + (num_frames - 1) // 4
1349
+ for k in range(frame_range):
1350
+ self._enc_conv_idx = [0]
1351
+ if k == 0:
1352
+ tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
1353
+ else:
1354
+ tile = x[
1355
+ :,
1356
+ :,
1357
+ 1 + 4 * (k - 1) : 1 + 4 * k,
1358
+ i : i + self.tile_sample_min_height,
1359
+ j : j + self.tile_sample_min_width,
1360
+ ]
1361
+ tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
1362
+ tile = self.quant_conv(tile)
1363
+ time.append(tile)
1364
+ row.append(torch.cat(time, dim=2))
1365
+ rows.append(row)
1366
+ self.clear_cache()
1367
+
1368
+ result_rows = []
1369
+ for i, row in enumerate(rows):
1370
+ result_row = []
1371
+ for j, tile in enumerate(row):
1372
+ # blend the above tile and the left tile
1373
+ # to the current tile and add the current tile to the result row
1374
+ if i > 0:
1375
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1376
+ if j > 0:
1377
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1378
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
1379
+ result_rows.append(torch.cat(result_row, dim=-1))
1380
+
1381
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
1382
+ return enc
1383
+
1384
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1385
+ r"""
1386
+ Decode a batch of images using a tiled decoder.
1387
+
1388
+ Args:
1389
+ z (`torch.Tensor`): Input batch of latent vectors.
1390
+ return_dict (`bool`, *optional*, defaults to `True`):
1391
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1392
+
1393
+ Returns:
1394
+ [`~models.vae.DecoderOutput`] or `tuple`:
1395
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1396
+ returned.
1397
+ """
1398
+ _, _, num_frames, height, width = z.shape
1399
+ sample_height = height * self.spatial_compression_ratio
1400
+ sample_width = width * self.spatial_compression_ratio
1401
+
1402
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1403
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1404
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1405
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1406
+
1407
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
1408
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
1409
+
1410
+ # Split z into overlapping tiles and decode them separately.
1411
+ # The tiles have an overlap to avoid seams between tiles.
1412
+ rows = []
1413
+ for i in range(0, height, tile_latent_stride_height):
1414
+ row = []
1415
+ for j in range(0, width, tile_latent_stride_width):
1416
+ self.clear_cache()
1417
+ time = []
1418
+ for k in range(num_frames):
1419
+ self._conv_idx = [0]
1420
+ tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
1421
+ tile = self.post_quant_conv(tile)
1422
+ decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
1423
+ time.append(decoded)
1424
+ row.append(torch.cat(time, dim=2))
1425
+ rows.append(row)
1426
+ self.clear_cache()
1427
+
1428
+ result_rows = []
1429
+ for i, row in enumerate(rows):
1430
+ result_row = []
1431
+ for j, tile in enumerate(row):
1432
+ # blend the above tile and the left tile
1433
+ # to the current tile and add the current tile to the result row
1434
+ if i > 0:
1435
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1436
+ if j > 0:
1437
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1438
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
1439
+ result_rows.append(torch.cat(result_row, dim=-1))
1440
+
1441
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
1442
+
1443
+ if not return_dict:
1444
+ return (dec,)
1445
+ return DecoderOutput(sample=dec)
1446
+
1447
+ def forward(
1448
+ self,
1449
+ sample: torch.Tensor,
1450
+ sample_posterior: bool = False,
1451
+ return_dict: bool = True,
1452
+ generator: Optional[torch.Generator] = None,
1453
+ ) -> Union[DecoderOutput, torch.Tensor]:
1454
+ """
1455
+ Args:
1456
+ sample (`torch.Tensor`): Input sample.
1457
+ return_dict (`bool`, *optional*, defaults to `True`):
1458
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1459
+ """
1460
+ x = sample
1461
+ posterior = self.encode(x).latent_dist
1462
+ if sample_posterior:
1463
+ z = posterior.sample(generator=generator)
1464
+ else:
1465
+ z = posterior.mode()
1466
+ dec = self.decode(z, return_dict=return_dict)
1467
+ return dec
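
Note that this variant of the autoencoder asserts single-frame chunks in both the encoder and decoder (see the `assert x.shape[2] == 1` lines above), matching how FlashWorld feeds it. A hedged usage sketch with the constructor defaults; a real checkpoint would be loaded with `from_pretrained` instead of random weights:

```python
import torch
from models import AutoencoderKLWan

vae = AutoencoderKLWan().eval()        # defaults: z_dim=16, 8x spatial compression
frame = torch.randn(1, 3, 1, 64, 64)   # (B, C, T=1, H, W), values in [-1, 1]

with torch.no_grad():
    posterior = vae.encode(frame).latent_dist
    z = posterior.mode()               # (1, 16, 1, 8, 8)
    recon = vae.decode(z).sample       # (1, 3, 1, 64, 64), clamped to [-1, 1]
```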
models/reconstruction_model.py ADDED
@@ -0,0 +1,261 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import math
6
+ import numpy as np
7
+
8
+ from utils import zero_init, EMANorm, create_rays
9
+
10
+ import einops
11
+
12
+ from .render import gaussian_render
13
+
14
+ from utils import quaternion_to_matrix
15
+
16
+ def inverse_sigmoid(x):
17
+ if isinstance(x, torch.Tensor):
18
+ return torch.log(x/(1-x))
19
+ else:
20
+ return math.log(x/(1-x))
21
+
22
+ def inverse_softplus(x, beta=1):
23
+ if isinstance(x, torch.Tensor):
24
+ return (torch.exp(beta * x) - 1).log() / beta
25
+ else:
26
+ return math.log((math.exp(beta * x) - 1)) / beta
27
+
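These inverses let initialization pick a raw value whose activated result is a chosen target (e.g., a starting opacity after sigmoid, or a starting scale after softplus). A quick numeric check:

```python
import torch
import torch.nn.functional as F

p = torch.tensor(0.1)             # target opacity after sigmoid
raw = torch.log(p / (1 - p))      # inverse_sigmoid(p)
print(torch.sigmoid(raw))         # tensor(0.1000)

s = torch.tensor(0.5)             # target scale after softplus (beta=1)
raw_s = (torch.exp(s) - 1).log()  # inverse_softplus(s, beta=1)
print(F.softplus(raw_s))          # tensor(0.5000)
```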
28
+ import copy
29
+
30
+ import math
31
+ import torch
32
+ import torch.nn as nn
33
+ import numpy as np
34
+
35
+ from .autoencoder_kl_wan import WanCausalConv3d, WanRMS_norm, unpatchify
36
+
37
+
38
+ class WANDecoderPixelAligned3DGSReconstructionModel(nn.Module):
39
+ def __init__(self,
40
+ vae_model,
41
+ feat_dim,
42
+ # num_remove_decoder_up_blocks=0,
43
+ # num_points_per_pixel=4,
44
+ use_network_checkpointing=True,
45
+ use_render_checkpointing=True
46
+ ):
47
+ super().__init__()
48
+
49
+ self.decoder = copy.deepcopy(vae_model.decoder).requires_grad_(True)
50
+ self.post_quant_conv = copy.deepcopy(vae_model.post_quant_conv).requires_grad_(True)
51
+
52
+ self.extra_conv_in = WanCausalConv3d(feat_dim, self.decoder.conv_in.weight.shape[0], 3, padding=1)
53
+
54
+ time_pad = self.extra_conv_in._padding[4]
55
+ self.extra_conv_in.padding = (0, self.extra_conv_in._padding[2], self.extra_conv_in._padding[0])
56
+ self.extra_conv_in._padding = (0, 0, 0, 0, 0, 0)
57
+ self.extra_conv_in.weight = torch.nn.Parameter(self.extra_conv_in.weight[:, :, time_pad:].clone())
58
+
59
+ with torch.no_grad():
60
+ self.extra_conv_in.weight.data.zero_()
61
+ self.extra_conv_in.bias.data.zero_()
62
+
63
+ # remove one block
64
+ # self.decoder.up_blocks = self.decoder.up_blocks[:-1]
65
+ dims = [self.decoder.dim * u for u in [self.decoder.dim_mult[-1]] + self.decoder.dim_mult[::-1]]
66
+ # self.decoder.up_blocks[-1].upsampler.mode = None
67
+ # self.decoder.up_blocks[-1].upsampler.resample = nn.Identity()
68
+ # self.decoder.up_blocks[-1].avg_shortcut = None
69
+
70
+ self.decoder.norm_out = WanRMS_norm(dims[-1], images=False, bias=False)
71
+ self.decoder.conv_out = nn.Identity()
72
+
73
+ # add ema_norm for vae
74
+ # for i_level in reversed(range(len(self.decoder.up_blocks))):
75
+ # if self.decoder.up_blocks[i_level].upsampler is not None:
76
+ # self.decoder.up_blocks[i_level].upsampler.resample = nn.Sequential(
77
+ # self.decoder.up_blocks[i_level].upsampler.resample,
78
+ # )
79
+
80
+ self.patch_size = vae_model.config.patch_size
81
+ # assert dims[-1] % 4 == 0
82
+ self.gs_head = PixelAligned3DGS(dims[-1], num_points_per_pixel=2)
83
+
84
+ del self.decoder.up_blocks[0].upsampler.time_conv
85
+ del self.decoder.up_blocks[1].upsampler.time_conv
86
+
87
+ self.decoder.conv_out = nn.Identity()
88
+
89
+ self.network_checkpointing = use_network_checkpointing
90
+ self.render_checkpointing = use_render_checkpointing
91
+
92
+ def decode(self, feats, z):
93
+ ## conv1
94
+ x = self.decoder.conv_in(self.post_quant_conv(z)) + self.extra_conv_in(feats)
95
+
96
+ ## middle
97
+ if self.network_checkpointing and torch.is_grad_enabled():
98
+ x = torch.utils.checkpoint.checkpoint(self.decoder.mid_block, x, None, [0], use_reentrant=False)
99
+ else:
100
+ x = self.decoder.mid_block(x, None, [0])
101
+
102
+ ## upsamples
103
+ for i, up_block in enumerate(self.decoder.up_blocks):
104
+ if self.network_checkpointing and torch.is_grad_enabled():
105
+ x = torch.utils.checkpoint.checkpoint(up_block, x, None, [0], True, use_reentrant=False)
106
+ else:
107
+ x = up_block(x, None, [0], first_chunk=True)
108
+
109
+ # head
110
+ x = self.decoder.norm_out(x)
111
+ x = self.decoder.nonlinearity(x)
112
+ x = self.decoder.conv_out(x)
113
+
114
+ # if self.patch_size is not None:
115
+ # x = unpatchify(x, patch_size=self.patch_size)
116
+
117
+ return x
118
+
119
+ def forward(self, feats, z, cameras):
120
+
121
+ x = self.decode(feats, z).squeeze(2)
122
+
123
+ gaussian_params = self.gs_head(x, cameras.flatten(0, 1)).unflatten(0, (cameras.shape[0], cameras.shape[1]))
124
+
125
+ return gaussian_params
126
+
127
+ # def forward(self, images, cameras, scene_chunk_lens):
128
+
129
+ # x, z, feats = self.encode(images)
130
+
131
+ # return self.reconstruct(x, z, feats, cameras, scene_chunk_lens)
132
+
133
+ @torch.amp.autocast(device_type='cuda', enabled=False)
134
+ def render(self, gaussian_params, camerass, height, width, bg_mode='random'):
135
+
136
+ camerass = camerass.to(torch.float32)
137
+
138
+ test_c2ws = torch.eye(4, device=camerass.device)[None][None].repeat(camerass.shape[0], camerass.shape[1], 1, 1).float()
139
+ test_c2ws[:, :, :3, :3] = quaternion_to_matrix(camerass[:, :, :4])
140
+ test_c2ws[:, :, :3, 3] = camerass[:, :, 4:7]
141
+
142
+ test_intr = torch.eye(3, device=camerass.device)[None, None].repeat(camerass.shape[0], camerass.shape[1], 1, 1).float()
143
+ fx, fy, cx, cy = camerass[:, :, 7:11].split([1, 1, 1, 1], dim=-1)
144
+
145
+ test_intr = torch.cat([fx * width, fy * height, cx * width, cy * height], dim=-1)
146
+
147
+ return gaussian_render(gaussian_params, test_c2ws, test_intr, width, height, use_checkpoint=self.render_checkpointing, sh_degree=self.gs_head.sh_degree, bg_mode=bg_mode)
148
+
149
+ from torch.autograd import Function
150
+
151
+ class _trunc_exp(Function):
152
+ @staticmethod
153
+ def forward(ctx, x):
154
+ ctx.save_for_backward(x)
155
+ return torch.exp(x)
156
+
157
+ @staticmethod
158
+ def backward(ctx, g):
159
+ x = ctx.saved_tensors[0]
160
+ return g * torch.exp(x.clamp(-10, 10))
161
+
162
+ trunc_exp = _trunc_exp.apply
163
+
164
+ class PixelAligned3DGS(nn.Module):
165
+ def __init__(
166
+ self,
167
+ embed_dim,
168
+ sh_degree=2,
169
+ use_mask=False,
170
+ scale_range=(0, 16), # related to pixel size
171
+ num_points_per_pixel=1,
172
+ ):
173
+ super().__init__()
174
+
175
+ self.sh_degree = sh_degree
176
+
177
+ # sh, uv_offset, depth, opacity, scales, rotations
178
+ # TODO: handle different sh_degree
179
+ self.gaussian_channels = [3 * (self.sh_degree + 1) ** 2, 2, 1, 1, 3, 4, (1 if use_mask else 0)]
180
+
181
+ self.gs_proj = nn.Conv2d(embed_dim, num_points_per_pixel * sum(self.gaussian_channels), 3, 1, 1)
182
+ self.register_buffer("lrs_mul", torch.Tensor(
183
+ [1] * 3 + # sh 0
184
+ [0.5] * 3 * ((self.sh_degree + 1) ** 2 - 1) + # other sh
185
+ [0.01] * 2 + # uv_offset
186
+ [1] * 1 + # depth
187
+ [1] * 1 + # opacity
188
+ [1] * 3 + # scales
189
+ [1] * 4 + # rotations
190
+ [0.1] * (1 if use_mask else 0) # mask
191
+ ).repeat(num_points_per_pixel), persistent=True)
192
+
193
+ self.lrs_mul = self.lrs_mul / self.lrs_mul.max()
194
+
195
+ self.use_mask = use_mask
196
+
197
+ self.scale_range = scale_range
198
+
199
+ with torch.no_grad():
200
+ self.gs_proj.weight.data.zero_()
201
+ self.gs_proj.bias = nn.Parameter(torch.Tensor(
202
+ [0.0] * 3 * (self.sh_degree + 1) ** 2 + # sh
203
+ [0.0] * 2 + # uv_offset
204
+ [math.log(1)] * 1 + # depth
205
+ # [inverse_softplus(1)] * 1 + # depth
206
+ [inverse_sigmoid(0.1)] * 1 + # opacity
207
+ [inverse_sigmoid((1 - scale_range[0]) / (scale_range[1] - scale_range[0]))] * 3 + # scales (default: 1 hence the gaussian scale is equal to pixel size)
208
+ # [inverse_softplus(0.005)] * 3 + # scales (default: 1 hence the gaussian scale is equal to pixel size)
209
+ [1., 0, 0, 0] + # rotations
210
+ [inverse_sigmoid(0.9)] * (1 if use_mask else 0) # mask (default: 0.9)
211
+ ).repeat(num_points_per_pixel) / self.lrs_mul)
212
+
213
+ self.num_points_per_pixel = num_points_per_pixel
214
+
215
+ @torch.amp.autocast(device_type='cuda', enabled=False)
216
+ def forward(self, x, cameras):
217
+
218
+ x = x.to(torch.float32)
219
+ cameras = cameras.to(torch.float32)
220
+
221
+ BN, _, h, w = x.shape
222
+
223
+ local_gaussian_params = F.conv2d(x, self.gs_proj.weight * self.lrs_mul[:, None, None, None], self.gs_proj.bias * self.lrs_mul, stride=1, padding=1).unflatten(1, (self.num_points_per_pixel, -1))
224
+ # local_gaussian_params = F.conv2d(x, self.gs_proj.weight, self.gs_proj.bias, stride=1, padding=1).unflatten(1, (self.num_points_per_pixel, -1))
225
+
226
+ # batch * n_frame, num_points_per_pixel, c, h, w -> batch * n_frame, num_points_per_pixel, h, w, c
227
+ local_gaussian_params = local_gaussian_params.permute(0, 1, 3, 4, 2)
228
+
229
+ features, uv_offset, depth, opacity, scales, rotations, mask = local_gaussian_params.split(self.gaussian_channels, dim=-1)
230
+
231
+ rays_o, rays_d = create_rays(cameras[:, None].repeat(1, self.num_points_per_pixel, 1), uv_offset=uv_offset, h=h, w=w)
232
+
233
+ depth = trunc_exp(depth)
234
+ # depth = F.softplus(depth, beta=1)
235
+ xyz = (rays_o + depth * rays_d)
236
+
237
+ # features = features.unflatten(-1, (-1, 3))
238
+
239
+ opacity = torch.sigmoid(opacity)
240
+ if self.use_mask:
241
+ if torch.is_grad_enabled():
242
+ mask = torch.sigmoid(mask)
243
+ hard_mask = (mask > torch.rand_like(mask)).float()
244
+ opacity = opacity * (mask + (hard_mask - mask).detach())
245
+ else:
246
+ mask = torch.sigmoid(mask)
247
+ hard_mask = (mask > torch.rand_like(mask)).float()
248
+ opacity = opacity * hard_mask
249
+
250
+ fx, fy = cameras[:, 7:9].split([1, 1], dim=-1)
251
+ fx, fy = fx / w, fy / h
252
+ pixel_size = torch.sqrt(fx.pow(2) + fy.pow(2))[:, None, None, None] * depth
253
+ scales = (torch.sigmoid(scales) * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]) * pixel_size
254
+ # scales = F.softplus(scales, beta=1)
255
+
256
+ # It’s not required to be normalized for gspalt rasterization?
257
+ rotations = torch.nn.functional.normalize(rotations, dim=-1)
258
+
259
+ gaussian_params = torch.cat([xyz, opacity, scales, rotations, features], dim=-1)
260
+
261
+ return gaussian_params
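For reference, the per-Gaussian parameter layout produced by `PixelAligned3DGS.forward` follows the `torch.cat` above. A minimal sketch of splitting such a tensor back into its components (dummy shapes are illustrative only):

```python
import torch

sh_degree = 2
feature_dim = 3 * (sh_degree + 1) ** 2   # 27 SH coefficients for RGB
num_gaussians = 8
params = torch.randn(num_gaussians, 3 + 1 + 3 + 4 + feature_dim)

# Order matches the concatenation in PixelAligned3DGS.forward:
xyz, opacity, scales, rotations, features = params.split(
    [3, 1, 3, 4, feature_dim], dim=-1
)
print(xyz.shape, opacity.shape, scales.shape, rotations.shape, features.shape)
```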
models/render.py ADDED
@@ -0,0 +1,138 @@
+ import os
+ import time
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from gsplat import rasterization
+
+ # torch.backends.cuda.preferred_linalg_library(backend="magma")
+
+ """
+ modified from https://github.com/arthurhero/Long-LRM/blob/main/model/llrm.py
+ """
+ class GaussianRendererWithCheckpoint(torch.autograd.Function):
+     @staticmethod
+     def render(xyz, feature, scale, rotation, opacity, test_c2w, test_intr,
+                W, H, sh_degree, near_plane, far_plane, backgrounds):
+         test_w2c = test_c2w.float().inverse().unsqueeze(0)  # (1, 4, 4)
+         test_intr_i = torch.zeros(3, 3).to(test_intr.device)
+         test_intr_i[0, 0] = test_intr[0]
+         test_intr_i[1, 1] = test_intr[1]
+         test_intr_i[0, 2] = test_intr[2]
+         test_intr_i[1, 2] = test_intr[3]
+         test_intr_i[2, 2] = 1
+         test_intr_i = test_intr_i.unsqueeze(0)  # (1, 3, 3)
+         rendering, alpha, _ = rasterization(xyz, rotation, scale, opacity, feature,
+                                             test_w2c, test_intr_i, W, H, sh_degree=sh_degree,
+                                             near_plane=near_plane, far_plane=far_plane,
+                                             render_mode="RGB+D",
+                                             backgrounds=backgrounds[None],
+                                             rasterize_mode='classic')  # (1, H, W, 4)
+         # rendering[..., 3:] = rendering[..., 3:] + far_plane * (1 - alpha)
+         return rendering
+
+     @staticmethod
+     def forward(ctx, xyz, feature, scale, rotation, opacity, test_c2ws, test_intr,
+                 W, H, sh_degree, near_plane, far_plane, backgrounds):
+         ctx.save_for_backward(xyz, feature, scale, rotation, opacity, test_c2ws, test_intr, backgrounds)
+         ctx.W = W
+         ctx.H = H
+         ctx.sh_degree = sh_degree
+         ctx.near_plane = near_plane
+         ctx.far_plane = far_plane
+         with torch.no_grad():
+             V, _ = test_intr.shape
+             renderings = torch.zeros(V, H, W, 4).to(xyz.device)
+             alphas = torch.rand(V, device=xyz.device)
+             for iv in range(V):
+                 rendering = GaussianRendererWithCheckpoint.render(xyz, feature, scale, rotation, opacity,
+                                                                   test_c2ws[iv], test_intr[iv], W, H, sh_degree, near_plane, far_plane, backgrounds[iv])
+                 renderings[iv:iv+1] = rendering
+
+         renderings = renderings.requires_grad_()
+         return renderings
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         # Re-render each view with gradients enabled and backpropagate through it,
+         # trading compute for memory (views are never kept in the autograd graph).
+         xyz, feature, scale, rotation, opacity, test_c2ws, test_intr, backgrounds = ctx.saved_tensors
+         xyz = xyz.detach().requires_grad_()
+         feature = feature.detach().requires_grad_()
+         scale = scale.detach().requires_grad_()
+         rotation = rotation.detach().requires_grad_()
+         opacity = opacity.detach().requires_grad_()
+         W = ctx.W
+         H = ctx.H
+         sh_degree = ctx.sh_degree
+         near_plane = ctx.near_plane
+         far_plane = ctx.far_plane
+         with torch.enable_grad():
+             V, _ = test_intr.shape
+             for iv in range(V):
+                 rendering = GaussianRendererWithCheckpoint.render(xyz, feature, scale, rotation, opacity,
+                                                                   test_c2ws[iv], test_intr[iv], W, H, sh_degree, near_plane, far_plane, backgrounds[iv])
+                 rendering.backward(grad_output[iv:iv+1])
+
+         return xyz.grad, feature.grad, scale.grad, rotation.grad, opacity.grad, None, None, None, None, None, None, None, None
+
+ def gaussian_render(gaussian_params, test_c2ws, test_intr, W, H, near_plane=0.01, far_plane=1000, use_checkpoint=False, sh_degree=0, bg_mode='random'):
+
+     if not torch.is_grad_enabled():
+         use_checkpoint = False
+
+     # opengl2colmap, see https://github.com/imlixinyang/Director3D/blob/main/modules/renderers/gaussians_renderer.py
+     test_c2ws[:, :, :3, 1:3] *= -1
+
+     device = test_intr.device
+     B, V, _ = test_intr.shape
+
+     renderings = []
+
+     for ib in range(B):
+         if bg_mode == 'random':
+             backgrounds = torch.rand(V, 3).to(device)
+         elif bg_mode == 'white':
+             backgrounds = torch.ones(V, 3).to(device)
+         elif bg_mode == 'black':
+             backgrounds = torch.zeros(V, 3).to(device)
+         else:
+             raise ValueError(f"Invalid background mode: {bg_mode}")
+
+         xyz_i, opacity_i, scale_i, rotation_i, feature_i = gaussian_params[ib].float().split([3, 1, 3, 4, (sh_degree + 1)**2 * 3], dim=-1)
+
+         opacity_i = opacity_i.squeeze(-1)
+         feature_i = feature_i.reshape(-1, (sh_degree + 1)**2, 3)
+
+         if use_checkpoint:
+             renderings.append(GaussianRendererWithCheckpoint.apply(xyz_i, feature_i, scale_i, rotation_i, opacity_i, test_c2ws[ib], test_intr[ib], W, H, sh_degree, near_plane, far_plane, backgrounds))
+         else:
+             rendering = torch.zeros(V, H, W, 4).to(device)
+             for iv in range(V):
+                 rendering[iv:iv+1] = GaussianRendererWithCheckpoint.render(xyz_i, feature_i, scale_i, rotation_i, opacity_i,
+                                                                            test_c2ws[ib][iv], test_intr[ib][iv], W, H, sh_degree, near_plane, far_plane, backgrounds[iv])
+
+             # test_w2c_i = test_c2ws[ib].float().inverse()  # (V, 4, 4)
+             # test_intr_i = torch.zeros(V, 3, 3).to(device)
+             # test_intr_i[:, 0, 0] = test_intr[ib, :, 0]
+             # test_intr_i[:, 1, 1] = test_intr[ib, :, 1]
+             # test_intr_i[:, 0, 2] = test_intr[ib, :, 2]
+             # test_intr_i[:, 1, 2] = test_intr[ib, :, 3]
+             # test_intr_i[:, 2, 2] = 1
+
+             # rendering, _, _ = rasterization(xyz_i, rotation_i, scale_i, opacity_i, feature_i,
+             #                                 test_w2c_i, test_intr_i, W, H, sh_degree=sh_degree,
+             #                                 near_plane=near_plane, far_plane=far_plane,
+             #                                 render_mode="RGB+D",
+             #                                 backgrounds=backgrounds,
+             #                                 rasterize_mode='classic')  # (V, H, W, 4)
+             renderings.append(rendering)
+
+     renderings = torch.stack(renderings, dim=0).permute(0, 1, 4, 2, 3).contiguous()  # (B, V, 4, H, W)
+     rgb = renderings[:, :, :3].mul_(2).add_(-1).clamp(-1, 1)
+     depth = renderings[:, :, 3:]
+     return rgb, depth
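A minimal sketch of calling `gaussian_render` with dummy inputs; shapes are illustrative, and the actual call (commented out) requires a CUDA device with `gsplat` installed:

```python
import torch

B, V, N, sh_degree = 1, 2, 128, 0
W = H = 64
# [xyz(3), opacity(1), scales(3), rotations(4), SH features]
gaussian_params = torch.randn(B, N, 3 + 1 + 3 + 4 + (sh_degree + 1) ** 2 * 3)
test_c2ws = torch.eye(4).expand(B, V, 4, 4).clone()            # OpenGL camera-to-world
test_intr = torch.tensor([W, H, W / 2, H / 2]).expand(B, V, 4).float()  # fx, fy, cx, cy in pixels

# rgb: (B, V, 3, H, W) in [-1, 1]; depth: (B, V, 1, H, W)
# rgb, depth = gaussian_render(gaussian_params.cuda(), test_c2ws.cuda(),
#                              test_intr.cuda(), W, H, sh_degree=sh_degree, bg_mode='white')
```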
models/transformer_wan.py ADDED
@@ -0,0 +1,601 @@
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
+ from diffusers.models.attention import FeedForward
+ from diffusers.models.attention_processor import Attention
+ from diffusers.models.cache_utils import CacheMixin
+ from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.models.modeling_utils import ModelMixin
+
+ try:
+     from sageattention import sageattn
+ except ImportError:
+     sageattn = None
+
+ class FP32LayerNorm(nn.LayerNorm):
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return F.layer_norm(
+             inputs,
+             self.normalized_shape,
+             self.weight,
+             self.bias,
+             self.eps,
+         ).to(inputs.dtype)
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ class WanAttnProcessor2_0:
+     def __init__(self):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         rotary_emb: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         encoder_hidden_states_img = None
+         if attn.add_k_proj is not None:
+             # 512 is the context length of the text encoder, hardcoded for now
+             image_context_length = encoder_hidden_states.shape[1] - 512
+             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         if attn.norm_q is not None:
+             query = attn.norm_q(query).to(hidden_states.dtype)
+         if attn.norm_k is not None:
+             key = attn.norm_k(key).to(hidden_states.dtype)
+
+         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+         if rotary_emb is not None:
+
+             def apply_rotary_emb(
+                 hidden_states: torch.Tensor,
+                 freqs_cos: torch.Tensor,
+                 freqs_sin: torch.Tensor,
+             ):
+                 x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
+                 x1, x2 = x[..., 0], x[..., 1]
+                 cos = freqs_cos[..., 0::2]
+                 sin = freqs_sin[..., 1::2]
+                 out = torch.empty_like(hidden_states)
+                 out[..., 0::2] = x1 * cos - x2 * sin
+                 out[..., 1::2] = x1 * sin + x2 * cos
+                 return out.type_as(hidden_states)
+
+             query = apply_rotary_emb(query, *rotary_emb)
+             key = apply_rotary_emb(key, *rotary_emb)
+
+         # I2V task
+         hidden_states_img = None
+         if encoder_hidden_states_img is not None:
+             key_img = attn.add_k_proj(encoder_hidden_states_img)
+             key_img = attn.norm_added_k(key_img)
+             value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+             key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+             value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+             if sageattn is not None:
+                 # Ensure kernels receive fp16/bf16 tensors under autocast
+                 if torch.is_autocast_enabled() and query.dtype not in (torch.float16, torch.bfloat16):
+                     target_dtype = torch.bfloat16
+                     query = query.to(target_dtype)
+                     key_img = key_img.to(target_dtype)
+                     value_img = value_img.to(target_dtype)
+                 hidden_states_img = sageattn(
+                     query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+                 )
+             else:
+                 hidden_states_img = F.scaled_dot_product_attention(
+                     query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+                 )
+
+             hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+             hidden_states_img = hidden_states_img.type_as(query)
+
+         if sageattn is not None:
+             # Ensure kernels receive fp16/bf16 tensors under autocast
+             if torch.is_autocast_enabled() and query.dtype not in (torch.float16, torch.bfloat16):
+                 target_dtype = torch.bfloat16
+                 query = query.to(target_dtype)
+                 key = key.to(target_dtype)
+                 value = value.to(target_dtype)
+             hidden_states = sageattn(
+                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+             )
+         else:
+             hidden_states = F.scaled_dot_product_attention(
+                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+             )
+
+         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+         hidden_states = hidden_states.type_as(query)
+
+         if hidden_states_img is not None:
+             hidden_states = hidden_states + hidden_states_img
+
+         hidden_states = attn.to_out[0](hidden_states)
+         hidden_states = attn.to_out[1](hidden_states)
+         return hidden_states
+
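A small standalone check of the interleaved rotary embedding used in `apply_rotary_emb` above: each (even, odd) channel pair is rotated by an angle theta. The values and helper name below are illustrative:

```python
import torch

def rotate_pairs(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # x: (..., 2 * n); theta: (..., n), one rotation angle per channel pair
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = torch.cos(theta), torch.sin(theta)
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

x = torch.tensor([1.0, 0.0, 0.0, 1.0])              # two channel pairs
theta = torch.tensor([torch.pi / 2, torch.pi / 2])  # rotate both by 90 degrees
print(rotate_pairs(x, theta))                       # ~[0, 1, -1, 0]
```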
+
+ class WanImageEmbedding(torch.nn.Module):
+     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
+         super().__init__()
+
+         self.norm1 = FP32LayerNorm(in_features)
+         self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+         self.norm2 = FP32LayerNorm(out_features)
+         if pos_embed_seq_len is not None:
+             self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+         else:
+             self.pos_embed = None
+
+     def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+         if self.pos_embed is not None:
+             batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+             encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+             encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
+         hidden_states = self.norm1(encoder_hidden_states_image)
+         hidden_states = self.ff(hidden_states)
+         hidden_states = self.norm2(hidden_states)
+         return hidden_states
+
+
+ class WanTimeTextImageEmbedding(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         time_freq_dim: int,
+         time_proj_dim: int,
+         text_embed_dim: int,
+         image_embed_dim: Optional[int] = None,
+         pos_embed_seq_len: Optional[int] = None,
+     ):
+         super().__init__()
+
+         self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+         self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+         self.act_fn = nn.SiLU()
+         self.time_proj = nn.Linear(dim, time_proj_dim)
+         self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+         self.image_embedder = None
+         if image_embed_dim is not None:
+             self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
+
+     def forward(
+         self,
+         timestep: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         encoder_hidden_states_image: Optional[torch.Tensor] = None,
+         timestep_seq_len: Optional[int] = None,
+     ):
+         timestep = self.timesteps_proj(timestep)
+         if timestep_seq_len is not None:
+             timestep = timestep.unflatten(0, (1, timestep_seq_len))
+
+         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+             timestep = timestep.to(time_embedder_dtype)
+         temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+         timestep_proj = self.time_proj(self.act_fn(temb))
+
+         encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+         if encoder_hidden_states_image is not None:
+             encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+         return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+ class WanRotaryPosEmbed(nn.Module):
+     def __init__(
+         self,
+         attention_head_dim: int,
+         patch_size: Tuple[int, int, int],
+         max_seq_len: int,
+         theta: float = 10000.0,
+     ):
+         super().__init__()
+
+         self.attention_head_dim = attention_head_dim
+         self.patch_size = patch_size
+         self.max_seq_len = max_seq_len
+
+         h_dim = w_dim = 2 * (attention_head_dim // 6)
+         t_dim = attention_head_dim - h_dim - w_dim
+         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+         freqs_cos = []
+         freqs_sin = []
+
+         for dim in [t_dim, h_dim, w_dim]:
+             freq_cos, freq_sin = get_1d_rotary_pos_embed(
+                 dim,
+                 max_seq_len,
+                 theta,
+                 use_real=True,
+                 repeat_interleave_real=True,
+                 freqs_dtype=freqs_dtype,
+             )
+             freqs_cos.append(freq_cos)
+             freqs_sin.append(freq_sin)
+
+         self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+         self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+         p_t, p_h, p_w = self.patch_size
+         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+         split_sizes = [
+             self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+             self.attention_head_dim // 3,
+             self.attention_head_dim // 3,
+         ]
+
+         freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+         freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+         freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+         freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+         freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+         freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+
+         return freqs_cos, freqs_sin
+
+
+ @maybe_allow_in_graph
+ class WanTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         ffn_dim: int,
+         num_heads: int,
+         qk_norm: str = "rms_norm_across_heads",
+         cross_attn_norm: bool = False,
+         eps: float = 1e-6,
+         added_kv_proj_dim: Optional[int] = None,
+     ):
+         super().__init__()
+
+         # 1. Self-attention
+         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+         self.attn1 = Attention(
+             query_dim=dim,
+             heads=num_heads,
+             kv_heads=num_heads,
+             dim_head=dim // num_heads,
+             qk_norm=qk_norm,
+             eps=eps,
+             bias=True,
+             cross_attention_dim=None,
+             out_bias=True,
+             processor=WanAttnProcessor2_0(),
+         )
+
+         # 2. Cross-attention
+         self.attn2 = Attention(
+             query_dim=dim,
+             heads=num_heads,
+             kv_heads=num_heads,
+             dim_head=dim // num_heads,
+             qk_norm=qk_norm,
+             eps=eps,
+             bias=True,
+             cross_attention_dim=None,
+             out_bias=True,
+             added_kv_proj_dim=added_kv_proj_dim,
+             added_proj_bias=True,
+             processor=WanAttnProcessor2_0(),
+         )
+         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+         # 3. Feed-forward
+         self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+         self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+         self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         temb: torch.Tensor,
+         rotary_emb: torch.Tensor,
+     ) -> torch.Tensor:
+         if temb.ndim == 4:
+             # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                 self.scale_shift_table.unsqueeze(0) + temb
+             ).chunk(6, dim=2)
+             # batch_size, seq_len, 1, inner_dim
+             shift_msa = shift_msa.squeeze(2)
+             scale_msa = scale_msa.squeeze(2)
+             gate_msa = gate_msa.squeeze(2)
+             c_shift_msa = c_shift_msa.squeeze(2)
+             c_scale_msa = c_scale_msa.squeeze(2)
+             c_gate_msa = c_gate_msa.squeeze(2)
+         else:
+             # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                 self.scale_shift_table + temb
+             ).chunk(6, dim=1)
+
+         # 1. Self-attention
+         norm_hidden_states = (self.norm1(hidden_states).mul_(1 + scale_msa).add_(shift_msa))
+         attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+         hidden_states += attn_output * gate_msa
+
+         # 2. Cross-attention
+         norm_hidden_states = self.norm2(hidden_states)
+         attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+         hidden_states += attn_output
+
+         # 3. Feed-forward
+         norm_hidden_states = (self.norm3(hidden_states).mul_(1 + c_scale_msa).add_(c_shift_msa))
+         ff_output = self.ffn(norm_hidden_states)
+         hidden_states += ff_output.mul_(c_gate_msa)
+
+         return hidden_states
+
+
+ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+     r"""
+     A Transformer model for video-like data used in the Wan model.
+
+     Args:
+         patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+             3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+         num_attention_heads (`int`, defaults to `40`):
+             The number of attention heads.
+         attention_head_dim (`int`, defaults to `128`):
+             The number of channels in each head.
+         in_channels (`int`, defaults to `16`):
+             The number of channels in the input.
+         out_channels (`int`, defaults to `16`):
+             The number of channels in the output.
+         text_dim (`int`, defaults to `4096`):
+             Input dimension for text embeddings.
+         freq_dim (`int`, defaults to `256`):
+             Dimension for sinusoidal time embeddings.
+         ffn_dim (`int`, defaults to `13824`):
+             Intermediate dimension in the feed-forward network.
+         num_layers (`int`, defaults to `40`):
+             The number of layers of transformer blocks to use.
+         cross_attn_norm (`bool`, defaults to `True`):
+             Enable cross-attention normalization.
+         qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+             The query/key normalization variant to use.
+         eps (`float`, defaults to `1e-6`):
+             Epsilon value for normalization layers.
+         added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+             The number of channels to use for the added key and value projections. If `None`, no projection is used.
+     """
+
+     _supports_gradient_checkpointing = True
+     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+     _no_split_modules = ["WanTransformerBlock"]
+     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+     _repeated_blocks = ["WanTransformerBlock"]
+
+     @register_to_config
+     def __init__(
+         self,
+         patch_size: Tuple[int] = (1, 2, 2),
+         num_attention_heads: int = 40,
+         attention_head_dim: int = 128,
+         in_channels: int = 16,
+         out_channels: int = 16,
+         text_dim: int = 4096,
+         freq_dim: int = 256,
+         ffn_dim: int = 13824,
+         num_layers: int = 40,
+         cross_attn_norm: bool = True,
+         qk_norm: Optional[str] = "rms_norm_across_heads",
+         eps: float = 1e-6,
+         image_dim: Optional[int] = None,
+         added_kv_proj_dim: Optional[int] = None,
+         rope_max_seq_len: int = 1024,
+         pos_embed_seq_len: Optional[int] = None,
+     ) -> None:
+         super().__init__()
+
+         inner_dim = num_attention_heads * attention_head_dim
+         out_channels = out_channels or in_channels
+
+         # 1. Patch & position embedding
+         self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+         self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+         # 2. Condition embeddings
+         # image_embedding_dim=1280 for the I2V model
+         self.condition_embedder = WanTimeTextImageEmbedding(
+             dim=inner_dim,
+             time_freq_dim=freq_dim,
+             time_proj_dim=inner_dim * 6,
+             text_embed_dim=text_dim,
+             image_embed_dim=image_dim,
+             pos_embed_seq_len=pos_embed_seq_len,
+         )
+
+         # 3. Transformer blocks
+         self.blocks = nn.ModuleList(
+             [
+                 WanTransformerBlock(
+                     inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         # 4. Output norm & projection
+         self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+         self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+         self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.LongTensor,
+         encoder_hidden_states: torch.Tensor,
+         encoder_hidden_states_image: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+         if attention_kwargs is not None:
+             attention_kwargs = attention_kwargs.copy()
+             lora_scale = attention_kwargs.pop("scale", 1.0)
+         else:
+             lora_scale = 1.0
+
+         if USE_PEFT_BACKEND:
+             # weight the lora layers by setting `lora_scale` for each PEFT layer
+             scale_lora_layers(self, lora_scale)
+         else:
+             if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                 logger.warning(
+                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                 )
+
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+         p_t, p_h, p_w = self.config.patch_size
+         post_patch_num_frames = num_frames // p_t
+         post_patch_height = height // p_h
+         post_patch_width = width // p_w
+
+         rotary_emb = self.rope(hidden_states)
+
+         hidden_states = self.patch_embedding(hidden_states)
+         hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+         # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+         if timestep.ndim == 2:
+             ts_seq_len = timestep.shape[1]
+             timestep = timestep.flatten()  # batch_size * seq_len
+         else:
+             ts_seq_len = None
+
+         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+             timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
+         )
+         if ts_seq_len is not None:
+             # batch_size, seq_len, 6, inner_dim
+             timestep_proj = timestep_proj.unflatten(2, (6, -1))
+         else:
+             # batch_size, 6, inner_dim
+             timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+         if encoder_hidden_states_image is not None:
+             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+         if True:
+             # Run the transformer blocks in bfloat16 regardless of the ambient dtype.
+             encoder_hidden_states = encoder_hidden_states.to(torch.bfloat16)
+             timestep_proj = timestep_proj.to(torch.bfloat16)
+             rotary_emb = [rotary_emb[0].to(torch.bfloat16), rotary_emb[1].to(torch.bfloat16)]
+             hidden_states = hidden_states.to(torch.bfloat16)
+
+         # 4. Transformer blocks
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             for block in self.blocks:
+                 hidden_states = self._gradient_checkpointing_func(
+                     block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+                 )
+         else:
+             for block in self.blocks:
+                 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+
+         # 5. Output norm, projection & unpatchify
+         if temb.ndim == 3:
+             # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+             shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+             shift = shift.squeeze(2)
+             scale = scale.squeeze(2)
+         else:
+             # batch_size, inner_dim
+             shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+         # Move the shift and scale tensors to the same device as hidden_states.
+         # When using multi-GPU inference via accelerate these will be on the
+         # first device rather than the last device, which hidden_states ends up
+         # on.
+         shift = shift.to(hidden_states.device)
+         scale = scale.to(hidden_states.device)
+
+         hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
+         hidden_states = self.proj_out(hidden_states)
+
+         hidden_states = hidden_states.reshape(
+             batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+         )
+         hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+         output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+         if USE_PEFT_BACKEND:
+             # remove `lora_scale` from each PEFT layer
+             unscale_lora_layers(self, lora_scale)
+
+         if not return_dict:
+             return (output,)
+
+         return Transformer2DModelOutput(sample=output)
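A self-contained sketch of the unpatchify bookkeeping at the end of the forward pass above (pure tensor ops; sizes are illustrative):

```python
import torch

B, C, F_, H, W = 1, 4, 2, 8, 8
p_t, p_h, p_w = 1, 2, 2
ppf, pph, ppw = F_ // p_t, H // p_h, W // p_w

# After proj_out: one token per 3D patch, C * p_t * p_h * p_w channels each.
hidden = torch.randn(B, ppf * pph * ppw, C * p_t * p_h * p_w)

# Same reshape/permute/flatten sequence as in WanTransformer3DModel.forward:
hidden = hidden.reshape(B, ppf, pph, ppw, p_t, p_h, p_w, -1)
hidden = hidden.permute(0, 7, 1, 4, 2, 5, 3, 6)
out = hidden.flatten(6, 7).flatten(4, 5).flatten(2, 3)
print(out.shape)  # torch.Size([1, 4, 2, 8, 8]) -> back to (B, C, F, H, W)
```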
quant.py ADDED
@@ -0,0 +1,195 @@
+ import gc
+ from typing import Tuple
+ import copy
+ import torch
+ import tqdm
+
+
+ def cleanup_memory():
+     gc.collect()
+     torch.cuda.empty_cache()
+
+
+ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
+     """Quantize a tensor using a per-tensor static scaling factor.
+
+     Args:
+         tensor: The input tensor.
+     """
+     finfo = torch.finfo(torch.float8_e4m3fn)
+     # Calculate the scale as dtype max divided by absmax.
+     # Since .abs() creates a new tensor, we use aminmax to get
+     # the min and max first and then calculate the absmax.
+     if tensor.numel() == 0:
+         # Deal with empty tensors (triggered by empty MoE experts)
+         min_val, max_val = (
+             torch.tensor(-16.0, dtype=tensor.dtype),
+             torch.tensor(16.0, dtype=tensor.dtype),
+         )
+     else:
+         min_val, max_val = tensor.aminmax()
+     amax = torch.maximum(min_val.abs(), max_val.abs())
+     scale = finfo.max / amax.clamp(min=1e-12)
+     # Scale and clamp the tensor to bring it into the representable
+     # range of the float8 data type (as the default cast is unsaturated).
+     qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
+     # Return both the float8 data and the inverse scale (as float),
+     # as both are required as inputs to torch._scaled_mm.
+     qweight = qweight.to(torch.float8_e4m3fn)
+     scale = scale.float().reciprocal()
+     return qweight, scale
+
+
+ def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
+     """Quantizes a floating-point tensor to FP8 (E4M3 format) using static scaling.
+
+     Performs uniform quantization of the input tensor by:
+     1. Scaling the tensor values using the provided inverse scale factor
+     2. Clamping values to the representable range of the FP8 E4M3 format
+     3. Converting to the FP8 data type
+
+     Args:
+         tensor (torch.Tensor): Input tensor to be quantized (any floating-point dtype)
+         inv_scale (float): Inverse of the quantization scale factor (1/scale)
+             (Must be pre-calculated based on tensor statistics)
+
+     Returns:
+         torch.Tensor: Quantized tensor in torch.float8_e4m3fn format
+
+     Note:
+         - Uses the E4M3 format (4 exponent bits, 3 mantissa bits, no infinity/nan)
+         - This is a static quantization (the scale factor must be pre-determined)
+         - For dynamic quantization, see per_tensor_quantize()
+     """
+     finfo = torch.finfo(torch.float8_e4m3fn)
+     qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
+     return qweight.to(torch.float8_e4m3fn)
+
+
+ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype, native_fp8_support=False):
+     """Performs an FP8 GEMM (General Matrix Multiplication) with optional native hardware support.
+
+     Args:
+         A (torch.Tensor): Input tensor A (FP8 or other dtype)
+         A_scale (torch.Tensor/float): Scale factor for tensor A
+         B (torch.Tensor): Input tensor B (FP8 or other dtype)
+         B_scale (torch.Tensor/float): Scale factor for tensor B
+         bias (torch.Tensor/None): Optional bias tensor
+         out_dtype (torch.dtype): Output data type
+         native_fp8_support (bool): Whether to use hardware-accelerated FP8 operations
+
+     Returns:
+         torch.Tensor: Result of the GEMM operation
+     """
+     if A.numel() == 0:
+         # Deal with empty tensors (triggered by empty MoE experts)
+         return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
+
+     if native_fp8_support:
+         need_reshape = A.dim() == 3
+         if need_reshape:
+             batch_size = A.shape[0]
+             A_input = A.reshape(-1, A.shape[-1])
+         else:
+             batch_size = None
+             A_input = A
+         output = torch._scaled_mm(
+             A_input,
+             B.t(),
+             out_dtype=out_dtype,
+             scale_a=A_scale,
+             scale_b=B_scale,
+             bias=bias,
+         )
+         if need_reshape:
+             output = output.reshape(
+                 batch_size, output.shape[0] // batch_size, output.shape[1]
+             )
+     else:
+         # Fallback: dequantize both operands and use a regular linear.
+         output = torch.nn.functional.linear(
+             A.to(out_dtype) * A_scale,
+             B.to(out_dtype) * B_scale.to(out_dtype),
+             bias=bias,
+         )
+
+     return output
+
+ def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
+     if "." in name:
+         parent_name = name.rsplit(".", 1)[0]
+         child_name = name[len(parent_name) + 1:]
+         parent = model.get_submodule(parent_name)
+     else:
+         parent_name = ""
+         parent = model
+         child_name = name
+     setattr(parent, child_name, new_module)
+
+
+ # Class responsible for quantizing weights
+ class FP8DynamicLinear(torch.nn.Module):
+     def __init__(
+         self,
+         weight: torch.Tensor,
+         weight_scale: torch.Tensor,
+         bias: torch.nn.Parameter,
+         native_fp8_support: bool = False,
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.weight = torch.nn.Parameter(weight, requires_grad=False)
+         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
+         self.bias = bias
+         self.native_fp8_support = native_fp8_support
+         self.dtype = dtype
+
+     # @torch.compile
+     def forward(self, x):
+         if x.dtype != self.dtype:
+             x = x.to(self.dtype)
+         # Dynamic activation quantization: scale is computed per forward pass.
+         qinput, x_scale = per_tensor_quantize(x)
+         output = fp8_gemm(
+             A=qinput,
+             A_scale=x_scale,
+             B=self.weight,
+             B_scale=self.weight_scale,
+             bias=self.bias,
+             out_dtype=x.dtype,
+             native_fp8_support=self.native_fp8_support,
+         )
+         return output
+
+
+ def FluxFp8GeMMProcessor(model: torch.nn.Module):
+     """Processes a PyTorch model to convert eligible Linear layers to FP8 precision.
+
+     This function performs the following operations:
+     1. Checks for native FP8 support on the current GPU
+     2. Identifies target Linear layers in transformer blocks
+     3. Quantizes weights to FP8 format
+     4. Replaces original Linear layers with FP8DynamicLinear versions
+     5. Performs memory cleanup
+
+     Args:
+         model (torch.nn.Module): The neural network model to be processed.
+             Should contain transformer blocks with Linear layers.
+     """
+     native_fp8_support = (
+         torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+     )
+     named_modules = list(model.named_modules())
+     for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights to fp8"):
+         if isinstance(linear, torch.nn.Linear) and "blocks" in name:
+             quant_weight, weight_scale = per_tensor_quantize(linear.weight)
+             bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
+             quant_linear = FP8DynamicLinear(
+                 weight=quant_weight,
+                 weight_scale=weight_scale,
+                 bias=bias,
+                 native_fp8_support=native_fp8_support,
+                 dtype=linear.weight.dtype,
+             )
+             replace_module(model, name, quant_linear)
+             del linear.weight
+             del linear.bias
+             del linear
+     cleanup_memory()
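A quick numeric sanity check of the dynamic FP8 path above: quantize a weight and an activation per tensor, then matmul via the non-native fallback. This assumes a torch build with `float8_e4m3fn` support (the pinned torch 2.6 has it):

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 8)    # activation
B = torch.randn(16, 8)   # weight, (out_features, in_features)

qA, a_scale = per_tensor_quantize(A)
qB, b_scale = per_tensor_quantize(B)
out = fp8_gemm(qA, a_scale, qB, b_scale, bias=None,
               out_dtype=torch.float32, native_fp8_support=False)

ref = A @ B.t()
print((out - ref).abs().max())  # quantization error, typically well under 1
```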
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ torch==2.6.0
+ torchvision==0.21.0
+ triton==3.2.0
+ transformers==4.57.0
+ omegaconf==2.3.0
+ ninja==1.13.0
+ numpy==2.2.6
+ einops==0.8.1
+ moviepy==1.0.3
+ opencv-python==4.12.0.88
+ av==15.1.0
+ plyfile==1.1.2
+ ftfy==6.3.1
+ flask==3.1.2
+ gradio==5.49.1
+ gsplat==1.5.2
+ accelerate==1.10.1
+ git+https://github.com/huggingface/diffusers.git@447e8322f76efea55d4769cd67c372edbf0715b8
+ git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712
utils.py ADDED
@@ -0,0 +1,531 @@
+ from io import BytesIO
+ import math
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import importlib
+ from plyfile import PlyData, PlyElement
+
+ import copy
+
+ class EmbedContainer(nn.Module):
+     def __init__(self, tensor):
+         super().__init__()
+         self.tensor = nn.Parameter(tensor)
+
+     def forward(self):
+         return self.tensor
+
+ @torch.no_grad
+ def zero_init(module):
+     if type(module) is torch.nn.Conv2d or type(module) is torch.nn.Linear:
+         module.weight.zero_()
+         module.bias.zero_()
+     return module
+
+ def import_str(string):
+     # From https://github.com/CompVis/taming-transformers
+     module, cls = string.rsplit(".", 1)
+     return getattr(importlib.import_module(module, package=None), cls)
+
+ """
+ from https://github.com/Kai-46/minFM/blob/main/utils/ema.py
+ Exponential Moving Average (EMA) utilities for PyTorch models.
+
+ This module provides utilities for maintaining and updating EMA models,
+ which are commonly used to improve model stability and generalization
+ in training deep neural networks. It supports both regular tensors and
+ DTensors (from FSDP-wrapped models).
+ """
+ class EMA_FSDP:
+     def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
+         self.decay = decay
+         self.shadow = {}
+         self._init_shadow(fsdp_module)
+
+     @torch.no_grad()
+     def _init_shadow(self, fsdp_module):
+         # Check whether the module is FSDP-wrapped.
+         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+         if isinstance(fsdp_module, FSDP):
+             with FSDP.summon_full_params(fsdp_module, writeback=False):
+                 for n, p in fsdp_module.module.named_parameters():
+                     self.shadow[n] = p.detach().clone().float().cpu()
+         else:
+             for n, p in fsdp_module.named_parameters():
+                 self.shadow[n] = p.detach().clone().float().cpu()
+
+     @torch.no_grad()
+     def update(self, fsdp_module):
+         d = self.decay
+         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+         if isinstance(fsdp_module, FSDP):
+             with FSDP.summon_full_params(fsdp_module, writeback=False):
+                 for n, p in fsdp_module.module.named_parameters():
+                     self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
+         else:
+             for n, p in fsdp_module.named_parameters():
+                 self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
+
+     # Optional helpers ---------------------------------------------------
+     def state_dict(self):
+         return self.shadow  # picklable
+
+     def load_state_dict(self, sd):
+         self.shadow = {k: v.clone() for k, v in sd.items()}
+
+     def copy_to(self, fsdp_module):
+         # load EMA weights into an (unwrapped) copy of the generator
+         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+         with FSDP.summon_full_params(fsdp_module, writeback=True):
+             for n, p in fsdp_module.module.named_parameters():
+                 if n in self.shadow:
+                     p.data.copy_(self.shadow[n].to(device=p.device, dtype=p.dtype))
+
+ def create_raymaps(cameras, h, w):
+     rays_o, rays_d = create_rays(cameras, h, w)
+     raymaps = torch.cat([rays_d, rays_o - (rays_o * rays_d).sum(dim=-1, keepdim=True) * rays_d], dim=-1)
+     return raymaps
+
+ # def create_raymaps(cameras, h, w):
+ #     rays_o, rays_d = create_rays(cameras, h, w)
+ #     raymaps = torch.cat([rays_d, torch.cross(rays_d, rays_o, dim=-1)], dim=-1)
+ #     return raymaps
+
+ class EMANorm(nn.Module):
+     def __init__(self, beta):
+         super().__init__()
+         self.register_buffer('magnitude_ema', torch.ones([]))
+         self.beta = beta
+
+     def forward(self, x):
+         if self.training:
+             # Track the running mean square of the activations.
+             magnitude_cur = x.detach().to(torch.float32).square().mean()
+             self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema.to(torch.float32), self.beta))
+         input_gain = self.magnitude_ema.rsqrt()
+         x = x.mul(input_gain)
+         return x
+
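A tiny check of `EMANorm` above: during training it tracks the running mean square of its inputs and rescales them toward unit magnitude (values here are illustrative):

```python
import torch

norm = EMANorm(beta=0.99)
norm.train()
for _ in range(300):
    _ = norm(torch.randn(64) * 5.0)   # activations with std ~5
norm.eval()
out = norm(torch.randn(64) * 5.0)
print(out.std())                      # roughly 1 once the EMA has warmed up
```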
+ class TimestepEmbedding(nn.Module):
+     def __init__(self, dim, max_period=10000, time_factor: float = 1000.0, zero_weight: bool = True):
+         super().__init__()
+         self.max_period = max_period
+         self.time_factor = time_factor
+         self.dim = dim
+         if zero_weight:
+             self.weight = nn.Parameter(torch.zeros(dim))
+         else:
+             self.weight = None
+
+     def forward(self, t):
+         if self.weight is None:
+             return timestep_embedding(t, self.dim, self.max_period, self.time_factor)
+         else:
+             return timestep_embedding(t, self.dim, self.max_period, self.time_factor) * self.weight.unsqueeze(0)
+
+ @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
+ def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0):
+     """
+     Create sinusoidal timestep embeddings.
+     :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an (N, D) Tensor of positional embeddings.
+     """
+     t = time_factor * t
+     half = dim // 2
+     freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
+
+     args = t[:, None].float() * freqs[None]
+     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+     if dim % 2:
+         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+     if torch.is_floating_point(t):
+         embedding = embedding.to(t)
+     return embedding
+
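A minimal sketch of the sinusoidal embedding above for two diffusion times (the 1000x `time_factor` is the function default; the first call may be slow since the function is wrapped in `torch.compile`):

```python
import torch

t = torch.tensor([0.0, 0.5])        # normalized timesteps in [0, 1]
emb = timestep_embedding(t, dim=8)  # (2, 8): 4 cosine then 4 sine channels
print(emb.shape, emb[0])            # t=0 gives an all-ones cos half, zeros sin half
```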
150
+ def quaternion_to_matrix(quaternions):
151
+ """
152
+ Convert rotations given as quaternions to rotation matrices.
153
+ Args:
154
+ quaternions: quaternions with real part first,
155
+ as tensor of shape (..., 4).
156
+ Returns:
157
+ Rotation matrices as tensor of shape (..., 3, 3).
158
+ """
159
+ r, i, j, k = torch.unbind(quaternions, -1)
160
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
161
+
162
+ o = torch.stack(
163
+ (
164
+ 1 - two_s * (j * j + k * k),
165
+ two_s * (i * j - k * r),
166
+ two_s * (i * k + j * r),
167
+ two_s * (i * j + k * r),
168
+ 1 - two_s * (i * i + k * k),
169
+ two_s * (j * k - i * r),
170
+ two_s * (i * k - j * r),
171
+ two_s * (j * k + i * r),
172
+ 1 - two_s * (i * i + j * j),
173
+ ),
174
+ -1,
175
+ )
176
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
177
+
178
+ # from https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
179
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
180
+ """
181
+ Convert a unit quaternion to a standard form: one in which the real
182
+ part is non negative.
183
+
184
+ Args:
185
+ quaternions: Quaternions with real part first,
186
+ as tensor of shape (..., 4).
187
+
188
+ Returns:
189
+ Standardized quaternions as tensor of shape (..., 4).
190
+ """
191
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
192
+
193
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
194
+ """
195
+ Returns torch.sqrt(torch.max(0, x))
196
+ but with a zero subgradient where x is 0.
197
+ """
198
+ ret = torch.zeros_like(x)
199
+ positive_mask = x > 0
200
+ if torch.is_grad_enabled():
201
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
202
+ else:
203
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
204
+ return ret
205
+
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+     """
+     Convert rotations given as rotation matrices to quaternions.
+
+     Args:
+         matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+     Returns:
+         quaternions with real part first, as tensor of shape (..., 4).
+     """
+     if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+         raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+     batch_dim = matrix.shape[:-2]
+     m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+         matrix.reshape(batch_dim + (9,)), dim=-1
+     )
+
+     q_abs = _sqrt_positive_part(
+         torch.stack(
+             [
+                 1.0 + m00 + m11 + m22,
+                 1.0 + m00 - m11 - m22,
+                 1.0 - m00 + m11 - m22,
+                 1.0 - m00 - m11 + m22,
+             ],
+             dim=-1,
+         )
+     )
+
+     # we produce the desired quaternion multiplied by each of r, i, j, k
+     quat_by_rijk = torch.stack(
+         [
+             torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+             torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+             torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+             torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+         ],
+         dim=-2,
+     )
+
+     # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+     # the candidate won't be picked.
+     flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+     quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+     # if not for numerical problems, quat_candidates[i] should be the same (up to a
+     # sign) for all i; we pick the best-conditioned one (with the largest denominator)
+     indices = q_abs.argmax(dim=-1, keepdim=True)
+     expand_dims = list(batch_dim) + [1, 4]
+     gather_indices = indices.unsqueeze(-1).expand(expand_dims)
+     out = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2)
+     return standardize_quaternion(out)
+
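+ # Editor's sketch (illustrative, not from the original commit): a round-trip
+ # consistency check between the two conversions above. `_demo_quaternion_roundtrip`
+ # is a hypothetical helper; it assumes torch and F (torch.nn.functional) are
+ # imported earlier in app.py, as the surrounding code implies.
+ def _demo_quaternion_roundtrip():
+     q = F.normalize(torch.randn(8, 4), dim=-1)  # random unit quaternions
+     q[:, 0] = q[:, 0].abs() + 0.1               # keep the real part safely positive
+     q = F.normalize(q, dim=-1)
+     q_back = matrix_to_quaternion(quaternion_to_matrix(q))
+     assert torch.allclose(q, q_back, atol=1e-4)
+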
+ @torch.amp.autocast(device_type="cuda", enabled=False)
+ def normalize_cameras(cameras, return_meta=False, ref_w2c=None, T_norm=None, n_frame=None):
+     # cameras: (B, N, 11) rows of [quaternion (4), translation (3), fx, fy, cx, cy];
+     # poses are re-expressed relative to the first frame and translations are
+     # rescaled to roughly unit norm
+     B, N = cameras.shape[:2]
+
+     c2ws = torch.zeros(B, N, 3, 4, device=cameras.device)
+
+     c2ws[..., :3, :3] = quaternion_to_matrix(cameras[..., 0:4])
+     c2ws[..., :, 3] = cameras[..., 4:7]
+
+     _c2ws = c2ws
+
+     ref_w2c = torch.inverse(matrix_to_square(_c2ws[:, :1])) if ref_w2c is None else ref_w2c
+     _c2ws = (ref_w2c.repeat(1, N, 1, 1) @ matrix_to_square(_c2ws))[..., :3, :]
+
+     if n_frame is not None:
+         T_norm = _c2ws[..., :n_frame, :3, 3].norm(dim=-1).max(dim=1)[0][..., None, None] if T_norm is None else T_norm
+     else:
+         T_norm = _c2ws[..., :3, 3].norm(dim=-1).max(dim=1)[0][..., None, None] if T_norm is None else T_norm
+
+     _c2ws[..., :3, 3] = _c2ws[..., :3, 3] / (T_norm + 1e-2)
+
+     R = matrix_to_quaternion(_c2ws[..., :3, :3])
+     T = _c2ws[..., :3, 3]
+     cameras = torch.cat([R.float(), T.float(), cameras[..., 7:]], dim=-1)
+
+     if return_meta:
+         return cameras, ref_w2c, T_norm
+     else:
+         return cameras
+
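+ # Editor's sketch (illustrative, not from the original commit): normalizing a
+ # dummy batch of cameras. `_demo_normalize_cameras` is a hypothetical helper; the
+ # 11-dim layout [quaternion, translation, fx, fy, cx, cy] follows the slicing
+ # used by the function above.
+ def _demo_normalize_cameras():
+     cams = torch.zeros(1, 4, 11)
+     cams[..., 0] = 1.0                      # identity rotations (w, x, y, z)
+     cams[..., 4] = torch.linspace(0, 2, 4)  # translate along x
+     cams[..., 7:] = 0.5                     # normalized intrinsics
+     normed, ref_w2c, T_norm = normalize_cameras(cams, return_meta=True)
+     assert normed.shape == cams.shape       # poses are now relative to frame 0
+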
+ def create_rays(cameras, h, w, uv_offset=None):
+     prefix_shape = cameras.shape[:-1]
+     cameras = cameras.flatten(0, -2)
+     device = cameras.device
+     N = cameras.shape[0]
+
+     c2w = torch.eye(4, device=device)[None].repeat(N, 1, 1)
+     c2w[:, :3, :3] = quaternion_to_matrix(cameras[:, :4])
+     c2w[:, :3, 3] = cameras[:, 4:7]
+
+     # fx, cx are stored normalized by the original image width; fy, cy by the height
+     fx, fy, cx, cy = cameras[:, 7:].chunk(4, -1)
+
+     fx, cx = fx * w, cx * w
+     fy, cy = fy * h, cy * h
+
+     inds = torch.arange(0, h*w, device=device).expand(N, h*w)
+
+     i = inds % w + 0.5
+     j = torch.div(inds, w, rounding_mode='floor') + 0.5
+
+     u = i / cx + (uv_offset[..., 0].reshape(N, h*w) if uv_offset is not None else 0)
+     v = j / cy + (uv_offset[..., 1].reshape(N, h*w) if uv_offset is not None else 0)
+
+     zs = -torch.ones_like(i)
+     xs = -(u - 1) * cx / fx * zs
+     ys = (v - 1) * cy / fy * zs
+     directions = torch.stack((xs, ys, zs), dim=-1)
+
+     rays_d = F.normalize(directions @ c2w[:, :3, :3].transpose(-1, -2), dim=-1)
+
+     rays_o = c2w[..., :3, 3]  # [N, 3]
+     rays_o = rays_o[..., None, :].expand_as(rays_d)
+
+     rays_o = rays_o.reshape(*prefix_shape, h, w, 3)
+     rays_d = rays_d.reshape(*prefix_shape, h, w, 3)
+
+     return rays_o, rays_d
+
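+ # Editor's sketch (illustrative, not from the original commit): generating rays
+ # for a single pinhole camera. `_demo_create_rays` is a hypothetical helper; with
+ # fx = fy = cx = cy = 0.5 (normalized intrinsics, principal point at the image
+ # center), rays near the image center point approximately along the camera's
+ # -z viewing axis.
+ def _demo_create_rays(h=8, w=8):
+     cam = torch.tensor([[1.0, 0, 0, 0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5]])
+     rays_o, rays_d = create_rays(cam, h, w)
+     assert rays_o.shape == (1, h, w, 3) and rays_d.shape == (1, h, w, 3)
+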
+ def matrix_to_square(mat):
+     # append the homogeneous row [0, 0, 0, 1] to (..., 3, 4) pose matrices,
+     # yielding square (..., 4, 4) matrices
+     l = len(mat.shape)
+     if l == 3:
+         return torch.cat([mat, torch.tensor([0, 0, 0, 1]).repeat(mat.shape[0], 1, 1).to(mat.device)], dim=1)
+     elif l == 4:
+         return torch.cat([mat, torch.tensor([0, 0, 0, 1]).repeat(mat.shape[0], mat.shape[1], 1, 1).to(mat.device)], dim=2)
+
+ def export_ply_for_gaussians(path, gaussians, opacity_threshold=0.00, T_norm=None):
+
+     # each Gaussian row packs [xyz (3), opacity (1), scale (3), rotation (4), SH features]
+     sh_degree = int(math.sqrt((gaussians.shape[-1] - sum([3, 1, 3, 4])) / 3 - 1))
+
+     xyz, opacity, scale, rotation, feature = gaussians.float().split([3, 1, 3, 4, (sh_degree + 1)**2 * 3], dim=-1)
+
+     means3D = xyz.contiguous().float()
+     opacity = opacity.contiguous().float()
+     scales = scale.contiguous().float()
+     rotations = rotation.contiguous().float()
+     shs = feature.contiguous().float()  # [N, (sh_degree + 1)**2 * 3]
+
+     # prune by opacity
+     if opacity_threshold > 0:
+         mask = opacity[..., 0] >= opacity_threshold
+         means3D = means3D[mask]
+         opacity = opacity[mask]
+         scales = scales[mask]
+         rotations = rotations[mask]
+         shs = shs[mask]
+
+         print("Gaussian percentage: ", mask.float().mean())
+
+     if T_norm is not None:
+         means3D = means3D * T_norm.item()
+         scales = scales * T_norm.item()
+
+     # invert the activations (inverse sigmoid for opacity, log for scale)
+     # to make the values compatible with the original ply format
+     opacity = torch.log(opacity / (1 - opacity))
+     scales = torch.log(scales + 1e-8)
+
+     xyzs = means3D.detach()
+     f_dc = shs.detach().flatten(start_dim=1).contiguous()
+     opacities = opacity.detach()
+     scales = scales.detach()
+     rotations = rotations.detach()
+
+     l = ['x', 'y', 'z']
+     # All channels except the 3 DC
+     for i in range(f_dc.shape[1]):
+         l.append('f_dc_{}'.format(i))
+     l.append('opacity')
+     for i in range(scales.shape[1]):
+         l.append('scale_{}'.format(i))
+     for i in range(rotations.shape[1]):
+         l.append('rot_{}'.format(i))
+
+     # fastest approach: build the vertex data directly as a numpy recarray
+     attributes = torch.cat((xyzs, f_dc, opacities, scales, rotations), dim=1).cpu().numpy()
+
+     # create the recarray in one shot, avoiding per-row loops and type conversions
+     elements = np.rec.fromarrays([attributes[:, i] for i in range(attributes.shape[1])], names=l, formats=['f4'] * len(l))
+     el = PlyElement.describe(elements, 'vertex')
+
+     print(path)
+
+     PlyData([el]).write(path)
+
+     # A disabled alternative exporter kept by the author: it would write the
+     # compact .splat format instead of .ply (re-enabling it would also require
+     # io.BytesIO to be imported).
+     # plydata = PlyData([el])
+
+     # vert = plydata["vertex"]
+     # sorted_indices = np.argsort(
+     #     -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
+     #     / (1 + np.exp(-vert["opacity"]))
+     # )
+     # buffer = BytesIO()
+     # for idx in sorted_indices:
+     #     v = plydata["vertex"][idx]
+     #     position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
+     #     scales = np.exp(
+     #         np.array(
+     #             [v["scale_0"], v["scale_1"], v["scale_2"]],
+     #             dtype=np.float32,
+     #         )
+     #     )
+     #     rot = np.array(
+     #         [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
+     #         dtype=np.float32,
+     #     )
+     #     SH_C0 = 0.28209479177387814
+     #     color = np.array(
+     #         [
+     #             0.5 + SH_C0 * v["f_dc_0"],
+     #             0.5 + SH_C0 * v["f_dc_1"],
+     #             0.5 + SH_C0 * v["f_dc_2"],
+     #             1 / (1 + np.exp(-v["opacity"])),
+     #         ]
+     #     )
+     #     buffer.write(position.tobytes())
+     #     buffer.write(scales.tobytes())
+     #     buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
+     #     buffer.write(
+     #         ((rot / np.linalg.norm(rot)) * 128 + 128)
+     #         .clip(0, 255)
+     #         .astype(np.uint8)
+     #         .tobytes()
+     #     )
+
+     # with open(path + '.splat', "wb") as f:
+     #     f.write(buffer.getvalue())
+
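+ # Editor's sketch (illustrative, not from the original commit): a minimal export
+ # call. `_demo_export_ply` and the output path are hypothetical; with sh_degree = 0
+ # each Gaussian is a 14-dim row [xyz (3), opacity (1), scale (3), quaternion (4),
+ # SH DC color (3)], where opacity must lie in (0, 1) and scale must be positive,
+ # since the exporter applies the inverse sigmoid / log before writing.
+ def _demo_export_ply(path="demo.ply"):
+     n = 100
+     gaussians = torch.cat([
+         torch.randn(n, 3),                       # means
+         torch.rand(n, 1) * 0.98 + 0.01,          # opacities in (0.01, 0.99)
+         torch.rand(n, 3) * 0.1 + 1e-3,           # positive scales
+         F.normalize(torch.randn(n, 4), dim=-1),  # unit rotation quaternions
+         torch.rand(n, 3),                        # degree-0 SH coefficients
+     ], dim=-1)
+     export_ply_for_gaussians(path, gaussians)
+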
+ @torch.amp.autocast(device_type="cuda", enabled=False)
+ def quaternion_slerp(
+     q0, q1, fraction, spin: int = 0, shortestpath: bool = True
+ ):
+     """Return spherical linear interpolation between two quaternions.
+     Args:
+         q0: first quaternion
+         q1: second quaternion
+         fraction: how much to interpolate between q0 vs q1 (if 0, closer to q0; if 1, closer to q1)
+         spin: how much of an additional spin to place on the interpolation
+         shortestpath: whether to return the short or long path to rotation
+     """
+     d = (q0 * q1).sum(-1)
+     if shortestpath:
+         # flip q1 onto the hemisphere of q0 so the shorter arc is interpolated
+         neg = d < 0.0
+         q1 = q1.clone()
+         q1[neg] = -q1[neg]
+         d[neg] = -d[neg]
+
+     _d = d.clamp(0, 1.0)
+
+     # theta = torch.arccos(d) * fraction
+     # q2 = q1 - q0 * d
+     # q2 = q2 / (q2.norm(dim=-1) + 1e-10)
+
+     # return torch.cos(theta) * q0 + torch.sin(theta) * q2
+
+     angle = torch.acos(_d) + spin * math.pi
+     isin = 1.0 / (torch.sin(angle) + 1e-10)
+     q0_ = q0 * (torch.sin((1.0 - fraction) * angle) * isin)[..., None]
+     q1_ = q1 * (torch.sin(fraction * angle) * isin)[..., None]
+
+     q = q0_ + q1_
+
+     # fall back to q0 when the two quaternions are (nearly) identical
+     q[angle < 1e-5] = q0[angle < 1e-5]
+     # q[fraction < 1e-5] = q0[fraction < 1e-5]
+     # q[fraction > 1 - 1e-5] = q1[fraction > 1 - 1e-5]
+     # q[(d.abs() - 1).abs() < 1e-5] = q0[(d.abs() - 1).abs() < 1e-5]
+
+     return q
+
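+ # Editor's sketch (illustrative, not from the original commit): slerp halfway
+ # between the identity and a 90-degree rotation about z should give a 45-degree
+ # rotation about z. `_demo_quaternion_slerp` is a hypothetical helper and assumes
+ # torch and math are imported earlier in app.py.
+ def _demo_quaternion_slerp():
+     q0 = torch.tensor([[1.0, 0.0, 0.0, 0.0]])                                      # identity
+     q1 = torch.tensor([[math.cos(math.pi / 4), 0.0, 0.0, math.sin(math.pi / 4)]])  # 90 deg about z
+     q_mid = quaternion_slerp(q0, q1, torch.tensor([0.5]))
+     expected = torch.tensor([[math.cos(math.pi / 8), 0.0, 0.0, math.sin(math.pi / 8)]])
+     assert torch.allclose(q_mid, expected, atol=1e-5)
+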
+ def sample_from_two_pose(pose_a, pose_b, fraction, noise_strengths=[0, 0]):
+     """
+     Args:
+         pose_a: first pose
+         pose_b: second pose
+         fraction: interpolation weight in [0, 1] (0 gives pose_a, 1 gives pose_b)
+     """
+
+     quat_a = pose_a[..., :4]
+     quat_b = pose_b[..., :4]
+
+     dot = torch.sum(quat_a * quat_b, dim=-1, keepdim=True)
+     quat_b = torch.where(dot < 0, -quat_b, quat_b)
+
+     quaternion = quaternion_slerp(quat_a, quat_b, fraction)
+     quaternion = torch.nn.functional.normalize(quaternion + torch.randn_like(quaternion) * noise_strengths[0], dim=-1)
+
+     T = (1 - fraction)[:, None] * pose_a[..., 4:] + fraction[:, None] * pose_b[..., 4:]
+     T = T + torch.randn_like(T) * noise_strengths[1]
+
+     new_pose = pose_a.clone()
+     new_pose[..., :4] = quaternion
+     new_pose[..., 4:] = T
+     return new_pose
+
+ def sample_from_dense_cameras(dense_cameras, t, noise_strengths=[0, 0, 0, 0]):
+     N, C = dense_cameras.shape
+     M = t.shape[0]
+
+     # locate the two neighboring keyframes for each t in [0, 1]
+     left = torch.floor(t * (N - 1)).long().clamp(0, N - 2)
+     right = left + 1
+     fraction = t * (N - 1) - left
+
+     a = torch.gather(dense_cameras, 0, left[..., None].repeat(1, C))
+     b = torch.gather(dense_cameras, 0, right[..., None].repeat(1, C))
+
+     new_pose = sample_from_two_pose(a[:, :7], b[:, :7], fraction, noise_strengths=noise_strengths[:2])
+
+     # intrinsics are interpolated linearly
+     new_ins = (1 - fraction)[:, None] * a[:, 7:] + fraction[:, None] * b[:, 7:]
+
+     return torch.cat([new_pose, new_ins], dim=1)
+
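+ # Editor's sketch (illustrative, not from the original commit): sampling a smooth
+ # camera path. `_demo_sample_trajectory` is a hypothetical helper; dense_cameras
+ # holds N keyframes of 11-dim cameras and t selects continuous positions in [0, 1],
+ # with rotations slerped and translations/intrinsics interpolated linearly.
+ def _demo_sample_trajectory():
+     N = 5
+     dense_cameras = torch.zeros(N, 11)
+     dense_cameras[:, 0] = 1.0                      # identity rotations
+     dense_cameras[:, 4] = torch.linspace(0, 1, N)  # translate along x
+     dense_cameras[:, 7:] = 0.5                     # constant intrinsics
+     t = torch.linspace(0, 1, 16)
+     sampled = sample_from_dense_cameras(dense_cameras, t)
+     assert sampled.shape == (16, 11)
+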