Upload 86 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- icons/move_backward.png +0 -0
- icons/move_forward.png +0 -0
- icons/move_left.png +0 -0
- icons/move_right.png +0 -0
- icons/not_move_backward.png +0 -0
- icons/not_move_forward.png +0 -0
- icons/not_move_left.png +0 -0
- icons/not_move_right.png +0 -0
- icons/not_turn_down.png +0 -0
- icons/not_turn_left.png +0 -0
- icons/not_turn_right.png +0 -0
- icons/not_turn_up.png +0 -0
- icons/turn_down.png +0 -0
- icons/turn_left.png +0 -0
- icons/turn_right.png +0 -0
- icons/turn_up.png +0 -0
- models/Astra/checkpoints/Put ReCamMaster ckpt file here.txt +0 -0
- models/Astra/checkpoints/README.md +5 -0
- scripts/add_text_emb.py +161 -0
- scripts/add_text_emb_rl.py +161 -0
- scripts/add_text_emb_spatialvid.py +173 -0
- scripts/analyze_openx.py +243 -0
- scripts/analyze_pose.py +188 -0
- scripts/batch_drone.py +44 -0
- scripts/batch_infer.py +186 -0
- scripts/batch_nus.py +42 -0
- scripts/batch_rt.py +41 -0
- scripts/batch_spa.py +43 -0
- scripts/batch_walk.py +42 -0
- scripts/check.py +263 -0
- scripts/decode_openx.py +428 -0
- scripts/download_recam.py +7 -0
- scripts/download_wan2.1.py +5 -0
- scripts/encode_dynamic_videos.py +141 -0
- scripts/encode_openx.py +466 -0
- scripts/encode_rlbench_video.py +170 -0
- scripts/encode_sekai_video.py +162 -0
- scripts/encode_sekai_walking.py +249 -0
- scripts/encode_spatialvid.py +409 -0
- scripts/encode_spatialvid_first_frame.py +285 -0
- scripts/hud_logo.py +40 -0
- scripts/infer_demo.py +1458 -0
- scripts/infer_moe.py +1023 -0
- scripts/infer_moe_spatialvid.py +1008 -0
- scripts/infer_moe_test.py +976 -0
- scripts/infer_nus.py +500 -0
- scripts/infer_openx.py +614 -0
- scripts/infer_origin.py +1108 -0
- scripts/infer_recam.py +272 -0
- scripts/infer_rlbench.py +447 -0
icons/move_backward.png
ADDED
|
|
icons/move_forward.png
ADDED
|
|
icons/move_left.png
ADDED
|
|
icons/move_right.png
ADDED
|
|
icons/not_move_backward.png
ADDED
|
|
icons/not_move_forward.png
ADDED
|
|
icons/not_move_left.png
ADDED
|
|
icons/not_move_right.png
ADDED
|
|
icons/not_turn_down.png
ADDED
|
|
icons/not_turn_left.png
ADDED
|
|
icons/not_turn_right.png
ADDED
|
|
icons/not_turn_up.png
ADDED
|
|
icons/turn_down.png
ADDED
|
|
icons/turn_left.png
ADDED
|
|
icons/turn_right.png
ADDED
|
|
icons/turn_up.png
ADDED
|
|
models/Astra/checkpoints/Put ReCamMaster ckpt file here.txt
ADDED
|
File without changes
|
models/Astra/checkpoints/README.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
| 4 |
+
# ReCamMaster: Camera-Controlled Generative Rendering from A Single Video
|
| 5 |
+
Please refer to the [Github](https://github.com/KwaiVGI/ReCamMaster) README for usage.
|
scripts/add_text_emb.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pdb
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
class VideoEncoder(pl.LightningModule):
    """Hosts a WanVideoReCamMaster pipeline (text encoder + VAE) plus frame
    preprocessing for video encoding.

    NOTE(review): despite subclassing LightningModule, no training hooks are
    defined here — the class is only used as a CUDA-movable model container.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16),
                 target_size=(480, 832)):
        """Load the text encoder and VAE and build the frame pipeline.

        Args:
            text_encoder_path: path to the T5 text-encoder checkpoint.
            vae_path: path to the Wan2.1 VAE checkpoint.
            tiled / tile_size / tile_stride: tiling options forwarded to the
                VAE for memory-friendly encoding.
            target_size: (height, width) every frame is resized to.
                Defaults to (480, 832), the value previously hard-coded.
        """
        super().__init__()
        # Models are loaded on CPU in bfloat16; the caller moves the module
        # to GPU afterwards (see encode_scenes).
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
        self.target_size = target_size

        # Normalize RGB frames from [0, 1] to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize a PIL image to ``self.target_size`` (height, width)."""
        target_height, target_width = self.target_size
        return v2.functional.resize(
            image,
            (round(target_height), round(target_width)),
            interpolation=v2.InterpolationMode.BILINEAR
        )

    def load_video_frames(self, video_path):
        """Load a full video as a normalized tensor of shape [C, T, H, W].

        Returns:
            Tensor of shape [C, T, H, W], or None when the video is empty.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        for frame_data in reader:
            frame = Image.fromarray(frame_data)
            frame = self.crop_and_resize(frame)
            frame = self.frame_process(frame)
            frames.append(frame)
        reader.close()

        if len(frames) == 0:
            return None

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames
+
|
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Attach a shared text-prompt embedding to every pre-encoded scene.

    For each scene directory under ``output_dir`` (derived from the entries
    of ``scenes_path``), loads ``encoded_video.pth``, and if any of the
    required keys (latents / cam_emb / prompt_emb) is missing, encodes the
    fixed walking-camera prompt once, stores it under ``prompt_emb`` and
    saves the file back in place.

    Args:
        scenes_path: directory whose entries name the scenes to process.
        text_encoder_path: T5 text-encoder checkpoint for VideoEncoder.
        vae_path: VAE checkpoint for VideoEncoder.
        output_dir: directory containing one sub-directory per scene with
            its ``encoded_video.pth``.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    prompt_emb = None  # encoded once on the first incomplete scene, then reused

    os.makedirs(output_dir, exist_ok=True)

    required_keys = ["latents", "cam_emb", "prompt_emb"]

    scene_names = os.listdir(scenes_path)
    for scene_name in tqdm(scene_names, total=len(scene_names)):
        save_dir = os.path.join(output_dir, scene_name.split('.')[0])
        encoded_path = os.path.join(save_dir, "encoded_video.pth")
        print(f"Checking scene {scene_name}...")

        # FIX: the original loaded the file unconditionally (the existence
        # check was commented out) and crashed on scenes never encoded.
        if not os.path.exists(encoded_path):
            print(f"Warning: {encoded_path} does not exist, skipping")
            continue

        # FIX: map_location="cpu" so GPU-saved tensors load on any host,
        # consistent with the sibling spatialvid/analyze scripts.
        data = torch.load(encoded_path, weights_only=False, map_location="cpu")
        missing_keys = [key for key in required_keys if key not in data]

        if missing_keys:
            print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
        else:
            print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb")
            continue

        with torch.no_grad():
            # The prompt is identical for every scene: encode it only once,
            # then drop the prompter to free memory.
            if processed_count == 0:
                print('encode prompt!!!')
                prompt_emb = encoder.pipe.encode_prompt(
                    "A video of a scene shot using a pedestrian's front camera while walking")
                del encoder.pipe.prompter

            data["prompt_emb"] = prompt_emb
            print("已添加/更新 prompt_emb 元素")

            # Save in place, overwriting the original file.
            torch.save(data, encoded_path)
            print(f"Saved encoded data: {encoded_path}")
            processed_count += 1

    print(processed_count)
    print(f"Encoding completed! Processed {processed_count} scenes.")
+
|
if __name__ == "__main__":
    # CLI entry point: parse the four path options and hand them straight
    # to encode_scenes.
    cli = argparse.ArgumentParser()
    for flag, preset in [
        ("--scenes_path", "/share_zhuyixuan05/zhuyixuan05/sekai-game-walking"),
        ("--text_encoder_path", "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"),
        ("--vae_path", "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"),
        ("--output_dir", "/share_zhuyixuan05/zhuyixuan05/sekai-game-walking"),
    ]:
        cli.add_argument(flag, type=str, default=preset)
    opts = cli.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path, opts.vae_path, opts.output_dir)
scripts/add_text_emb_rl.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pdb
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
class VideoEncoder(pl.LightningModule):
    """Hosts a WanVideoReCamMaster pipeline (text encoder + VAE) plus frame
    preprocessing.

    NOTE(review): despite subclassing LightningModule, no training hooks are
    defined here — the class is only used as a CUDA-movable model container.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        # Models are loaded on CPU in bfloat16; the caller moves the module
        # to GPU afterwards (see encode_scenes).
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiling options forwarded to the VAE for memory-friendly encoding.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Normalize RGB frames from [0, 1] to [-1, 1].
        self.frame_process = v2.Compose([
            # v2.CenterCrop(size=(900, 1600)),
            # v2.Resize(size=(900, 1600), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize a PIL image to the fixed 832x480 (width x height) target."""
        width, height = image.size  # NOTE(review): read but unused
        # print(width,height)
        width_ori, height_ori_ = 832 , 480
        image = v2.functional.resize(
            image,
            (round(height_ori_), round(width_ori)),
            interpolation=v2.InterpolationMode.BILINEAR
        )
        return image

    def load_video_frames(self, video_path):
        """Load a full video as a normalized tensor of shape [C, T, H, W].

        Returns None when the video contains no frames.
        """
        reader = imageio.get_reader(video_path)
        frames = []

        for frame_data in reader:
            frame = Image.fromarray(frame_data)
            frame = self.crop_and_resize(frame)
            frame = self.frame_process(frame)
            frames.append(frame)

        reader.close()

        if len(frames) == 0:
            return None

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames
| 62 |
+
|
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Attach a shared text-prompt embedding to every pre-encoded RLBench scene.

    For each scene directory under ``output_dir`` (derived from the entries
    of ``scenes_path``), loads ``encoded_video.pth``, and if any of the
    required keys (latents / cam_emb / prompt_emb) is missing, encodes the
    fixed robotic-arm prompt once, stores it under ``prompt_emb`` and saves
    the file back in place.

    Args:
        scenes_path: directory whose entries name the scenes to process.
        text_encoder_path: T5 text-encoder checkpoint for VideoEncoder.
        vae_path: VAE checkpoint for VideoEncoder.
        output_dir: directory containing one sub-directory per scene with
            its ``encoded_video.pth``.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    prompt_emb = None  # encoded once on the first incomplete scene, then reused

    os.makedirs(output_dir, exist_ok=True)

    required_keys = ["latents", "cam_emb", "prompt_emb"]

    scene_names = os.listdir(scenes_path)
    for scene_name in tqdm(scene_names, total=len(scene_names)):
        save_dir = os.path.join(output_dir, scene_name.split('.')[0])
        encoded_path = os.path.join(save_dir, "encoded_video.pth")
        print(f"Checking scene {scene_name}...")

        # FIX: the original loaded the file unconditionally (the existence
        # check was commented out) and crashed on scenes never encoded.
        if not os.path.exists(encoded_path):
            print(f"Warning: {encoded_path} does not exist, skipping")
            continue

        # FIX: map_location="cpu" so GPU-saved tensors load on any host,
        # consistent with the sibling spatialvid/analyze scripts.
        data = torch.load(encoded_path, weights_only=False, map_location="cpu")
        missing_keys = [key for key in required_keys if key not in data]

        if missing_keys:
            print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
        else:
            print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb")
            continue

        with torch.no_grad():
            # The prompt is identical for every scene: encode it only once,
            # then drop the prompter to free memory.
            if processed_count == 0:
                print('encode prompt!!!')
                prompt_emb = encoder.pipe.encode_prompt(
                    "a robotic arm executing precise manipulation tasks on a clean, organized desk")
                del encoder.pipe.prompter

            data["prompt_emb"] = prompt_emb
            print("已添加/更新 prompt_emb 元素")

            # Save in place, overwriting the original file.
            torch.save(data, encoded_path)
            print(f"Saved encoded data: {encoded_path}")
            processed_count += 1

    print(processed_count)
    print(f"Encoding completed! Processed {processed_count} scenes.")
| 148 |
+
|
if __name__ == "__main__":
    # CLI entry point: parse the four path options and hand them straight
    # to encode_scenes.
    cli = argparse.ArgumentParser()
    for flag, preset in [
        ("--scenes_path", "/share_zhuyixuan05/zhuyixuan05/rlbench"),
        ("--text_encoder_path", "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"),
        ("--vae_path", "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"),
        ("--output_dir", "/share_zhuyixuan05/zhuyixuan05/rlbench"),
    ]:
        cli.add_argument(flag, type=str, default=preset)
    opts = cli.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path, opts.vae_path, opts.output_dir)
scripts/add_text_emb_spatialvid.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pdb
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
class VideoEncoder(pl.LightningModule):
    """Hosts a WanVideoReCamMaster pipeline (text encoder + VAE) plus frame
    preprocessing.

    NOTE(review): despite subclassing LightningModule, no training hooks are
    defined here — the class is only used as a CUDA-movable model container.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        # Models are loaded on CPU in bfloat16; the caller moves the module
        # to GPU afterwards (see encode_scenes).
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiling options forwarded to the VAE for memory-friendly encoding.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Normalize RGB frames from [0, 1] to [-1, 1].
        self.frame_process = v2.Compose([
            # v2.CenterCrop(size=(900, 1600)),
            # v2.Resize(size=(900, 1600), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize a PIL image to the fixed 832x480 (width x height) target."""
        width, height = image.size  # NOTE(review): read but unused
        # print(width,height)
        width_ori, height_ori_ = 832 , 480
        image = v2.functional.resize(
            image,
            (round(height_ori_), round(width_ori)),
            interpolation=v2.InterpolationMode.BILINEAR
        )
        return image

    def load_video_frames(self, video_path):
        """Load a full video as a normalized tensor of shape [C, T, H, W].

        Returns None when the video contains no frames.
        """
        reader = imageio.get_reader(video_path)
        frames = []

        for frame_data in reader:
            frame = Image.fromarray(frame_data)
            frame = self.crop_and_resize(frame)
            frame = self.frame_process(frame)
            frames.append(frame)

        reader.close()

        if len(frames) == 0:
            return None

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames
| 62 |
+
|
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Sanitize pre-encoded SpatialVid scenes: detach prompt embeddings
    accidentally saved with ``requires_grad=True``.

    For each scene directory under ``output_dir``, loads
    ``encoded_video.pth``, warns when required keys are missing, and if the
    stored prompt-embedding context carries gradients, detaches it and saves
    the file back in place.

    Args:
        scenes_path: directory whose entries name the scenes to process.
        text_encoder_path: T5 text-encoder checkpoint for VideoEncoder.
        vae_path: VAE checkpoint for VideoEncoder.
        output_dir: directory containing one sub-directory per scene with
            its ``encoded_video.pth``.
    """
    # NOTE(review): the encoder is no longer used by the active code path
    # (prompt encoding is disabled below) but is kept because constructing
    # it was part of the original behavior.
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0

    os.makedirs(output_dir, exist_ok=True)

    required_keys = ["latents", "cam_emb", "prompt_emb"]

    scene_names = os.listdir(scenes_path)
    for scene_name in tqdm(scene_names, total=len(scene_names)):
        save_dir = os.path.join(output_dir, scene_name.split('.')[0])
        encoded_path = os.path.join(save_dir, "encoded_video.pth")

        # FIX: guard against scenes that were never encoded instead of
        # crashing inside torch.load.
        if not os.path.exists(encoded_path):
            print(f"Warning: {encoded_path} does not exist, skipping")
            continue

        data = torch.load(encoded_path, weights_only=False,
                          map_location="cpu")
        missing_keys = [key for key in required_keys if key not in data]

        if missing_keys:
            print(f"警告: 文件 {encoded_path} 中缺少以下必要元素: {missing_keys}")

        # FIX: the original indexed data['prompt_emb'] unconditionally right
        # after detecting it could be missing, raising KeyError; skip such
        # files instead.
        if 'prompt_emb' not in data:
            continue

        if data['prompt_emb']['context'].requires_grad:
            print(f"警告: 文件 {encoded_path} 中存在含梯度变量,已消除")

            data['prompt_emb']['context'] = data['prompt_emb']['context'].detach().clone()

            # Belt and braces: explicitly disable gradients as well.
            data['prompt_emb']['context'].requires_grad_(False)

            assert not data['prompt_emb']['context'].requires_grad, "梯度仍未消除!"
            torch.save(data, encoded_path)

        processed_count += 1

    print(processed_count)
    print(f"Encoding completed! Processed {processed_count} scenes.")
| 160 |
+
|
if __name__ == "__main__":
    # CLI entry point: parse the four path options and hand them straight
    # to encode_scenes.
    cli = argparse.ArgumentParser()
    for flag, preset in [
        ("--scenes_path", "/share_zhuyixuan05/zhuyixuan05/spatialvid"),
        ("--text_encoder_path", "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"),
        ("--vae_path", "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"),
        ("--output_dir", "/share_zhuyixuan05/zhuyixuan05/spatialvid"),
    ]:
        cli.add_argument(flag, type=str, default=preset)
    opts = cli.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path, opts.vae_path, opts.output_dir)
scripts/analyze_openx.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
|
def analyze_openx_dataset_frame_counts(dataset_path):
    """Analyze the latent frame-count distribution of an encoded OpenX dataset.

    Scans every episode directory under ``dataset_path`` that contains an
    ``encoded_video.pth``, reads the temporal length (T) of its latents,
    prints summary statistics, writes a ``frame_count_analysis.txt`` report
    into the dataset directory, and returns a dict with the key counters.
    Returns None when the path is missing or no episode could be loaded.
    """
    import datetime  # replaces the original inline __import__('datetime')

    print(f"🔧 分析OpenX数据集: {dataset_path}")

    if not os.path.exists(dataset_path):
        print(f" ⚠️ 路径不存在: {dataset_path}")
        return

    episode_dirs = []
    total_episodes = 0
    valid_episodes = 0

    # Collect every episode directory that already has an encoded video.
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir):
            total_episodes += 1
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                episode_dirs.append(episode_dir)
                valid_episodes += 1

    print(f"📊 总episode数: {total_episodes}")
    print(f"📊 有效episode数: {valid_episodes}")

    if len(episode_dirs) == 0:
        print("❌ 没有找到有效的episode")
        return

    # Frame-count histogram accumulators.
    frame_counts = []
    less_than_10 = 0
    less_than_8 = 0
    less_than_5 = 0
    error_count = 0

    print("🔧 开始分析帧数分布...")

    for episode_dir in tqdm(episode_dirs, desc="分析episodes"):
        try:
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu"
            )

            latents = encoded_data['latents']  # [C, T, H, W]
            frame_count = latents.shape[1]  # T dimension
            frame_counts.append(frame_count)

            if frame_count < 10:
                less_than_10 += 1
            if frame_count < 8:
                less_than_8 += 1
            if frame_count < 5:
                less_than_5 += 1

        except Exception as e:
            error_count += 1
            if error_count <= 5:  # only report the first five failures
                print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}")

    total_valid = len(frame_counts)
    # FIX: if every episode failed to load, the percentage computations
    # below would divide by zero — bail out early instead.
    if total_valid == 0:
        print("❌ 没有找到有效的episode")
        return

    print(f"\n📈 帧数分布统计:")
    print(f" 总有效episodes: {total_valid}")
    print(f" 错误episodes: {error_count}")
    print(f" 最小帧数: {min(frame_counts) if frame_counts else 0}")
    print(f" 最大帧数: {max(frame_counts) if frame_counts else 0}")
    print(f" 平均帧数: {sum(frame_counts) / len(frame_counts):.2f}")

    print(f"\n🎯 关键统计:")
    print(f" 帧数 < 5: {less_than_5:6d} episodes ({less_than_5/total_valid*100:.2f}%)")
    print(f" 帧数 < 8: {less_than_8:6d} episodes ({less_than_8/total_valid*100:.2f}%)")
    print(f" 帧数 < 10: {less_than_10:6d} episodes ({less_than_10/total_valid*100:.2f}%)")
    print(f" 帧数 >= 10: {total_valid-less_than_10:6d} episodes ({(total_valid-less_than_10)/total_valid*100:.2f}%)")

    # Detailed bucketed distribution.
    frame_counts.sort()
    print(f"\n📊 详细帧数分布:")

    ranges = [
        (1, 4, "1-4帧"),
        (5, 7, "5-7帧"),
        (8, 9, "8-9帧"),
        (10, 19, "10-19帧"),
        (20, 49, "20-49帧"),
        (50, 99, "50-99帧"),
        (100, float('inf'), "100+帧")
    ]

    for min_f, max_f, label in ranges:
        # Renamed genexp variable (was 'f') to avoid shadowing the file
        # handle used below.
        count = sum(1 for fc in frame_counts if min_f <= fc <= max_f)
        percentage = count / total_valid * 100
        print(f" {label:8s}: {count:6d} episodes ({percentage:5.2f}%)")

    # Suggested training configuration, assuming 4x temporal compression.
    print(f"\n💡 训练配置建议:")
    time_compression_ratio = 4
    min_condition_compressed = 4 // time_compression_ratio  # 1 frame
    target_frames_compressed = 32 // time_compression_ratio  # 8 frames
    min_required_compressed = min_condition_compressed + target_frames_compressed  # 9 frames

    usable_episodes = sum(1 for fc in frame_counts if fc >= min_required_compressed)
    usable_percentage = usable_episodes / total_valid * 100

    print(f" 最小条件帧数(压缩后): {min_condition_compressed}")
    print(f" 目标帧数(压缩后): {target_frames_compressed}")
    print(f" 最小所需帧数(压缩后): {min_required_compressed}")
    print(f" 可用于训练的episodes: {usable_episodes} ({usable_percentage:.2f}%)")

    # Persist the detailed report next to the dataset.
    output_file = os.path.join(dataset_path, "frame_count_analysis.txt")
    # FIX: explicit UTF-8 so the CJK/emoji content writes on any locale.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(f"OpenX Dataset Frame Count Analysis\n")
        f.write(f"Dataset Path: {dataset_path}\n")
        f.write(f"Analysis Date: {datetime.datetime.now()}\n\n")

        f.write(f"Total Episodes: {total_episodes}\n")
        f.write(f"Valid Episodes: {total_valid}\n")
        f.write(f"Error Episodes: {error_count}\n\n")

        f.write(f"Frame Count Statistics:\n")
        f.write(f" Min Frames: {min(frame_counts) if frame_counts else 0}\n")
        f.write(f" Max Frames: {max(frame_counts) if frame_counts else 0}\n")
        f.write(f" Avg Frames: {sum(frame_counts) / len(frame_counts):.2f}\n\n")

        f.write(f"Key Statistics:\n")
        f.write(f" < 5 frames: {less_than_5} ({less_than_5/total_valid*100:.2f}%)\n")
        f.write(f" < 8 frames: {less_than_8} ({less_than_8/total_valid*100:.2f}%)\n")
        f.write(f" < 10 frames: {less_than_10} ({less_than_10/total_valid*100:.2f}%)\n")
        f.write(f" >= 10 frames: {total_valid-less_than_10} ({(total_valid-less_than_10)/total_valid*100:.2f}%)\n\n")

        f.write(f"Detailed Distribution:\n")
        for min_f, max_f, label in ranges:
            count = sum(1 for fc in frame_counts if min_f <= fc <= max_f)
            percentage = count / total_valid * 100
            f.write(f" {label}: {count} ({percentage:.2f}%)\n")

        f.write(f"\nTraining Configuration Recommendation:\n")
        f.write(f" Usable Episodes (>= {min_required_compressed} compressed frames): {usable_episodes} ({usable_percentage:.2f}%)\n")

        # Dump every frame count, 20 per line.
        f.write(f"\nAll Frame Counts:\n")
        for i, count in enumerate(frame_counts):
            f.write(f"{count}")
            if (i + 1) % 20 == 0:
                f.write("\n")
            else:
                f.write(", ")

    print(f"\n💾 详细统计已保存到: {output_file}")

    return {
        'total_valid': total_valid,
        'less_than_10': less_than_10,
        'less_than_8': less_than_8,
        'less_than_5': less_than_5,
        'frame_counts': frame_counts,
        'usable_episodes': usable_episodes
    }
| 168 |
+
|
| 169 |
+
def quick_sample_analysis(dataset_path, sample_size=1000):
    """Quick random-sample analysis for a rough estimate on large datasets.

    Samples up to ``sample_size`` episode folders (each containing an
    ``encoded_video.pth``), reads each episode's compressed frame count,
    prints the sampled distribution and extrapolates it to the full dataset.

    Fix over the previous version: bails out with a message instead of
    raising ZeroDivisionError when every sampled episode fails to load.

    Args:
        dataset_path: Root directory whose sub-folders are episodes.
        sample_size: Maximum number of episodes to sample.
    """
    print(f"🚀 快速采样分析 (样本数: {sample_size})")

    # Collect every episode directory that actually has an encoded video.
    episode_dirs = []
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir):
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                episode_dirs.append(episode_dir)

    if len(episode_dirs) == 0:
        print("❌ 没有找到有效的episode")
        return

    # Random sampling without replacement.
    import random
    sample_dirs = random.sample(episode_dirs, min(sample_size, len(episode_dirs)))

    frame_counts = []
    less_than_10 = 0

    for episode_dir in tqdm(sample_dirs, desc="采样分析"):
        try:
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu"
            )

            # Dim 1 of the latents is treated as the compressed frame count.
            # NOTE(review): assumes a (C, T, ...) latent layout — confirm
            # against the encoder script.
            frame_count = encoded_data['latents'].shape[1]
            frame_counts.append(frame_count)

            if frame_count < 10:
                less_than_10 += 1

        except Exception:
            # Corrupt/unreadable episodes are simply excluded from the sample.
            continue

    total_sample = len(frame_counts)
    if total_sample == 0:
        # Every sampled episode failed to load; avoid dividing by zero below.
        print("❌ 采样的episode全部加载失败")
        return

    percentage_less_than_10 = less_than_10 / total_sample * 100

    print(f"📊 采样结果:")
    print(f"  采样数量: {total_sample}")
    print(f"  < 10帧: {less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f"  >= 10帧: {total_sample - less_than_10} ({100 - percentage_less_than_10:.2f}%)")
    print(f"  平均帧数: {sum(frame_counts) / len(frame_counts):.2f}")

    # Extrapolate the sampled ratio to the full dataset.
    total_episodes = len(episode_dirs)
    estimated_less_than_10 = int(total_episodes * percentage_less_than_10 / 100)

    print(f"\n🔮 全数据集估算:")
    print(f"  总episodes: {total_episodes}")
    print(f"  估算 < 10帧: {estimated_less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f"  估算 >= 10帧: {total_episodes - estimated_less_than_10} ({100 - percentage_less_than_10:.2f}%)")
|
| 227 |
+
|
| 228 |
+
if __name__ == "__main__":
    import argparse

    # CLI: full frame-count scan by default, fast random sampling with --quick.
    cli = argparse.ArgumentParser(description="分析OpenX数据集的帧数分布")
    cli.add_argument(
        "--dataset_path",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
        help="OpenX编码数据集路径",
    )
    cli.add_argument("--quick", action="store_true", help="快速采样分析模式")
    cli.add_argument("--sample_size", type=int, default=1000, help="快速模式的采样数量")
    opts = cli.parse_args()

    if opts.quick:
        quick_sample_analysis(opts.dataset_path, opts.sample_size)
    else:
        analyze_openx_dataset_frame_counts(opts.dataset_path)
|
scripts/analyze_pose.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
from pose_classifier import PoseClassifier
|
| 6 |
+
import torch
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
|
| 9 |
+
def analyze_turning_patterns_detailed(dataset_path, num_samples=50):
    """Analyze per-sample motion classes from reference-relative pose changes.

    Walks ``<dataset_path>/samples/<sample>/poses.json``, builds a 7D
    [translation(3), relative-quaternion(4)] sequence per sample, classifies
    it with ``PoseClassifier``, prints a detailed per-sample and per-class
    report, and returns ``(all_analyses, class_samples)``.

    Fixes over the previous version: percentage computations are guarded
    against zero frames / zero samples, and summary totals use
    ``.get(..., 0)`` so a classifier result missing a class key cannot
    raise KeyError.

    Args:
        dataset_path: Dataset root containing a ``samples`` directory.
        num_samples: Maximum number of samples to analyze.

    Returns:
        all_analyses: List of per-sample analysis dicts from the classifier.
        class_samples: defaultdict mapping dominant class name -> sample infos.
    """
    classifier = PoseClassifier()
    samples_path = os.path.join(dataset_path, "samples")

    all_analyses = []
    sample_count = 0

    # Samples grouped by their dominant motion class.
    class_samples = defaultdict(list)

    print("=== 开始分析样本(基于相对于reference的变化)===")

    # Sorted so the report order is deterministic.
    for item in sorted(os.listdir(samples_path)):
        if sample_count >= num_samples:
            break

        sample_dir = os.path.join(samples_path, item)
        if os.path.isdir(sample_dir):
            poses_path = os.path.join(sample_dir, "poses.json")
            if os.path.exists(poses_path):
                try:
                    with open(poses_path, 'r') as f:
                        poses_data = json.load(f)

                    target_relative_poses = poses_data['target_relative_poses']

                    if len(target_relative_poses) > 0:
                        # Build the reference-relative 7D pose vectors.
                        pose_vecs = []
                        for pose_data in target_relative_poses:
                            # Translation is already expressed relative to the reference.
                            translation = torch.tensor(pose_data['relative_translation'], dtype=torch.float32)

                            # Rotation must be derived from current and reference quaternions.
                            current_rotation = torch.tensor(pose_data['current_rotation'], dtype=torch.float32)
                            reference_rotation = torch.tensor(pose_data['reference_rotation'], dtype=torch.float32)

                            # q_relative = q_ref^-1 * q_current
                            relative_rotation = calculate_relative_rotation(current_rotation, reference_rotation)

                            # 7D vector: [relative_translation, relative_rotation]
                            pose_vec = torch.cat([translation, relative_rotation], dim=0)
                            pose_vecs.append(pose_vec)

                        if pose_vecs:
                            pose_sequence = torch.stack(pose_vecs, dim=0)

                            analysis = classifier.analyze_pose_sequence(pose_sequence)
                            analysis['sample_name'] = item
                            all_analyses.append(analysis)

                            total_frames = analysis['total_frames']

                            print(f"\n--- 样本 {sample_count + 1}: {item} ---")
                            print(f"总帧数: {analysis['total_frames']}")
                            print(f"总距离: {analysis['total_distance']:.4f}")

                            # Per-class frame distribution for this sample.
                            class_dist = analysis['class_distribution']
                            print(f"分类分布:")
                            for class_name, count in class_dist.items():
                                # Guard: a degenerate sample could report zero frames.
                                percentage = count / total_frames * 100 if total_frames else 0.0
                                print(f"  {class_name}: {count} 帧 ({percentage:.1f}%)")

                            # Show how the first few poses were classified.
                            print(f"前3帧的详细分类过程:")
                            for i in range(min(3, len(pose_vecs))):
                                debug_info = classifier.debug_single_pose(
                                    pose_vecs[i][:3], pose_vecs[i][3:7]
                                )
                                print(f"  帧{i}: {debug_info['classification']} "
                                      f"(yaw: {debug_info['yaw_angle_deg']:.2f}°, "
                                      f"forward: {debug_info['forward_movement']:.3f})")

                            # Contiguous motion segments found by the classifier.
                            print(f"运动段落:")
                            for i, segment in enumerate(analysis['motion_segments']):
                                print(f"  段落{i+1}: {segment['class']} (帧 {segment['start_frame']}-{segment['end_frame']}, 持续 {segment['duration']} 帧)")

                            # Dominant motion class for this sample.
                            dominant_class = max(class_dist.items(), key=lambda x: x[1])
                            dominant_class_name = dominant_class[0]
                            dominant_percentage = dominant_class[1] / total_frames * 100 if total_frames else 0.0

                            print(f"主要运动类型: {dominant_class_name} ({dominant_percentage:.1f}%)")

                            class_samples[dominant_class_name].append({
                                'name': item,
                                'percentage': dominant_percentage,
                                'analysis': analysis
                            })

                            sample_count += 1

                except Exception as e:
                    print(f"❌ 处理样本 {item} 时出错: {e}")

    print("\n" + "="*60)
    print("=== 按类别分组的样本统计(基于相对于reference的变化)===")

    # Per-class listing of samples, purest samples first.
    for class_name in ['forward', 'backward', 'left_turn', 'right_turn']:
        samples = class_samples[class_name]
        print(f"\n🔸 {class_name.upper()} 类样本 (共 {len(samples)} 个):")

        if samples:
            samples.sort(key=lambda x: x['percentage'], reverse=True)

            for i, sample_info in enumerate(samples, 1):
                print(f"  {i:2d}. {sample_info['name']} ({sample_info['percentage']:.1f}%)")

                # Only show segments lasting at least 2 frames.
                segments = sample_info['analysis']['motion_segments']
                segment_summary = [
                    f"{seg['class']}({seg['duration']})"
                    for seg in segments if seg['duration'] >= 2
                ]
                if segment_summary:
                    print(f"      段落: {' -> '.join(segment_summary)}")
        else:
            print("  (无样本)")

    print(f"\n" + "="*60)
    print("=== 总体统计 ===")

    # .get(..., 0) so an analysis missing a class key cannot raise KeyError.
    total_forward = sum(a['class_distribution'].get('forward', 0) for a in all_analyses)
    total_backward = sum(a['class_distribution'].get('backward', 0) for a in all_analyses)
    total_left_turn = sum(a['class_distribution'].get('left_turn', 0) for a in all_analyses)
    total_right_turn = sum(a['class_distribution'].get('right_turn', 0) for a in all_analyses)
    total_frames = total_forward + total_backward + total_left_turn + total_right_turn

    def _pct(n):
        # Avoid ZeroDivisionError when nothing was classified at all.
        return n / total_frames * 100 if total_frames else 0.0

    print(f"总样本数: {len(all_analyses)}")
    print(f"总帧数: {total_frames}")
    print(f"Forward: {total_forward} 帧 ({_pct(total_forward):.1f}%)")
    print(f"Backward: {total_backward} 帧 ({_pct(total_backward):.1f}%)")
    print(f"Left Turn: {total_left_turn} 帧 ({_pct(total_left_turn):.1f}%)")
    print(f"Right Turn: {total_right_turn} 帧 ({_pct(total_right_turn):.1f}%)")

    # Distribution of samples by their dominant class.
    print(f"\n按主要类型的样本分布:")
    for class_name in ['forward', 'backward', 'left_turn', 'right_turn']:
        count = len(class_samples[class_name])
        percentage = count / len(all_analyses) * 100 if all_analyses else 0
        print(f"  {class_name}: {count} 样本 ({percentage:.1f}%)")

    return all_analyses, class_samples
|
| 160 |
+
|
| 161 |
+
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Compute the relative rotation quaternion ``q_rel = q_ref^-1 * q_current``.

    Both inputs are [w, x, y, z] quaternions, given either as tensors or as
    plain sequences. The reference's conjugate is used as its inverse, which
    is exact only for unit quaternions (the expected input here).

    Fix over the previous version: ``torch.as_tensor`` replaces
    ``torch.tensor`` so that callers passing tensors (as the analysis loop
    does) no longer trigger the copy-construct warning, and the result is
    assembled with ``torch.stack`` instead of ``torch.tensor`` over 0-dim
    tensors.

    Returns:
        A float32 tensor of shape (4,) holding the relative quaternion.
    """
    q_current = torch.as_tensor(current_rotation, dtype=torch.float32)
    q_ref = torch.as_tensor(reference_rotation, dtype=torch.float32)

    # Conjugate of the reference quaternion (its inverse for unit quaternions).
    w1 = q_ref[0]
    x1, y1, z1 = -q_ref[1], -q_ref[2], -q_ref[3]
    w2, x2, y2, z2 = q_current

    # Hamilton product: q_ref^-1 * q_current.
    relative_rotation = torch.stack([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])

    return relative_rotation
|
| 181 |
+
|
| 182 |
+
if __name__ == "__main__":
    # Dataset root holding the "samples" directory analyzed below.
    nus_dataset_root = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_2"

    print("开始详细分析pose分类(基于相对于reference的变化)...")
    analyses, grouped_by_class = analyze_turning_patterns_detailed(
        nus_dataset_root, num_samples=4000
    )

    print(f"\n🎉 分析完成! 共处理 {len(analyses)} 个样本")
|
scripts/batch_drone.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import random
import subprocess
import time

# Source episodes (each sub-folder holds an encoded_video.pth) and output dir.
src_root = "/share_zhuyixuan05/zhuyixuan05/spatialvid"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_spatialvid_drone_first"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_origin.py"  # change to your actual path

# Ensure the output directory exists before the first inference run.
os.makedirs(dst_root, exist_ok=True)

while True:
    # Pick a random episode sub-folder each iteration.
    subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
    if not subdirs:
        print("没有可用的子文件夹")
        break
    chosen = random.choice(subdirs)
    chosen_dir = os.path.join(src_root, chosen)
    pth_file = os.path.join(chosen_dir, "encoded_video.pth")
    if not os.path.exists(pth_file):
        print(f"{pth_file} 不存在,跳过")
        continue

    # Output file named after the episode folder.
    out_file = os.path.join(dst_root, f"{chosen}.mp4")
    if os.path.exists(out_file):
        # Already generated in an earlier iteration — skip so the endless
        # loop makes progress instead of re-rendering the same episode.
        continue
    print(f"开始生成: {pth_file} -> {out_file}")

    # Build the inference command.
    cmd = [
        "python", infer_script,
        "--condition_pth", pth_file,
        "--output_path", out_file,
        "--prompt", "exploring the world",
        "--modality_type", "sekai",
        "--direction", "right",
        "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step25000_first.ckpt",
        "--use_gt_prompt"
    ]

    # Pin the child process to a single GPU.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = "0"

    # Run inference.
    subprocess.run(cmd, env=env)
|
scripts/batch_infer.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import glob
|
| 6 |
+
|
| 7 |
+
def find_video_files(videos_dir):
    """Return every video file directly under *videos_dir*, sorted by path."""
    # Only .mp4 counts as a video here; extend this tuple to support more.
    extensions = ('.mp4',)
    matches = [
        path
        for ext in extensions
        for path in glob.glob(os.path.join(videos_dir, f"*{ext}"))
    ]
    return sorted(matches)
|
| 17 |
+
|
| 18 |
+
def run_inference(condition_video, direction, dit_path, output_dir):
    """Run one infer_nus.py job; return True on success, False on failure."""
    # Output name: "<input stem>_<direction><input extension>".
    input_filename = os.path.basename(condition_video)
    stem, suffix = os.path.splitext(input_filename)
    output_filename = f"{stem}_{direction}{suffix}"
    output_path = os.path.join(output_dir, output_filename)

    # Command for the single-video inference script.
    cmd = [
        "python",
        "infer_nus.py",
        "--condition_video",
        condition_video,
        "--direction",
        direction,
        "--dit_path",
        dit_path,
        "--output_path",
        output_path,
    ]

    print(f"🎬 生成 {direction} 方向视频: {input_filename} -> {output_filename}")
    print(f"   命令: {' '.join(cmd)}")

    try:
        # check=True raises on a non-zero exit code; output is captured so
        # the child's stderr can be reported on failure.
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ 生成失败: {e}")
        print(f"   错误输出: {e.stderr}")
        return False

    print(f"✅ 成功生成: {output_path}")
    return True
|
| 47 |
+
|
| 48 |
+
def batch_inference(args):
    """Run inference for every video in args.videos_dir, once per direction.

    Skips outputs that already exist unless args.overwrite is set, tracks
    success/failure counts, and prints a final summary plus the list of
    generated .mp4 files.
    """
    videos_dir = args.videos_dir
    output_dir = args.output_dir
    directions = args.directions
    dit_path = args.dit_path

    # Validate the input directory before doing any work.
    if not os.path.exists(videos_dir):
        print(f"❌ 视频目录不存在: {videos_dir}")
        return

    # Create the output directory.
    os.makedirs(output_dir, exist_ok=True)
    print(f"📁 输出目录: {output_dir}")

    # Find all candidate video files.
    video_files = find_video_files(videos_dir)

    if not video_files:
        print(f"❌ 在 {videos_dir} 中没有找到视频文件")
        return

    print(f"🎥 找到 {len(video_files)} 个视频文件:")
    for video in video_files:
        print(f"  - {os.path.basename(video)}")

    print(f"🎯 将为每个视频生成以下方向: {', '.join(directions)}")
    print(f"📊 总共将生成 {len(video_files) * len(directions)} 个视频")

    # Progress counters: one task per (video, direction) pair.
    total_tasks = len(video_files) * len(directions)
    completed_tasks = 0
    failed_tasks = 0

    # Batch processing: outer loop over videos, inner loop over directions.
    for i, video_file in enumerate(video_files, 1):
        print(f"\n{'='*60}")
        print(f"处理视频 {i}/{len(video_files)}: {os.path.basename(video_file)}")
        print(f"{'='*60}")

        for j, direction in enumerate(directions, 1):
            print(f"\n--- 方向 {j}/{len(directions)}: {direction} ---")

            # Check whether the output already exists (naming mirrors
            # run_inference: "<stem>_<direction><ext>").
            input_filename = os.path.basename(video_file)
            name_parts = os.path.splitext(input_filename)
            output_filename = f"{name_parts[0]}_{direction}{name_parts[1]}"
            output_path = os.path.join(output_dir, output_filename)

            if os.path.exists(output_path) and not args.overwrite:
                # Existing outputs count as completed so progress stays accurate.
                print(f"⏭️ 文件已存在,跳过: {output_filename}")
                completed_tasks += 1
                continue

            # Run the single-video inference subprocess.
            success = run_inference(
                condition_video=video_file,
                direction=direction,
                dit_path=dit_path,
                output_dir=output_dir,
            )

            if success:
                completed_tasks += 1
            else:
                failed_tasks += 1

            # Show running progress.
            current_progress = completed_tasks + failed_tasks
            print(f"📈 进度: {current_progress}/{total_tasks} "
                  f"(成功: {completed_tasks}, 失败: {failed_tasks})")

    # Final summary.
    print(f"\n{'='*60}")
    print(f"🎉 批量推理完成!")
    print(f"📊 总任务数: {total_tasks}")
    print(f"✅ 成功: {completed_tasks}")
    print(f"❌ 失败: {failed_tasks}")
    print(f"📁 输出目录: {output_dir}")

    if failed_tasks > 0:
        print(f"⚠️ 有 {failed_tasks} 个任务失败,请检查日志")

    # List everything present in the output directory (not only this run's files).
    if completed_tasks > 0:
        print(f"\n📋 生成的文件:")
        generated_files = glob.glob(os.path.join(output_dir, "*.mp4"))
        for file_path in sorted(generated_files):
            print(f"  - {os.path.basename(file_path)}")
|
| 138 |
+
|
| 139 |
+
def main():
    """Parse CLI options and either preview (--dry_run) or run the batch."""
    parser = argparse.ArgumentParser(description="批量对nus/videos目录下的所有视频生成不同方向的输出")

    parser.add_argument("--videos_dir", type=str, default="/home/zhuyixuan05/ReCamMaster/nus/videos/4032",
                        help="输入视频目录路径")

    parser.add_argument("--output_dir", type=str, default="nus/infer_results/batch_dynamic_4032_noise",
                        help="输出视频目录路径")

    parser.add_argument("--directions", nargs="+",
                        default=["left_turn", "right_turn"],
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="要生成的方向列表")

    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/nus_dynamic/step15000_dynamic.ckpt",
                        help="训练好的DiT模型路径")

    parser.add_argument("--overwrite", action="store_true",
                        help="是否覆盖已存在的输出文件")

    parser.add_argument("--dry_run", action="store_true",
                        help="只显示将要执行的任务,不实际运行")

    args = parser.parse_args()

    if args.dry_run:
        # Preview mode: list every (video, direction) task without running it.
        print("🔍 预览模式 - 只显示任务,不执行")
        videos_dir = args.videos_dir
        video_files = find_video_files(videos_dir)

        print(f"📁 输入目录: {videos_dir}")
        print(f"📁 输出目录: {args.output_dir}")
        print(f"🎥 找到视频: {len(video_files)} 个")
        print(f"🎯 生成方向: {', '.join(args.directions)}")
        print(f"📊 总任务数: {len(video_files) * len(args.directions)}")

        print(f"\n将要执行的任务:")
        for video in video_files:
            for direction in args.directions:
                # Same output-naming scheme as run_inference/batch_inference.
                input_name = os.path.basename(video)
                name_parts = os.path.splitext(input_name)
                output_name = f"{name_parts[0]}_{direction}{name_parts[1]}"
                print(f"  {input_name} -> {output_name} ({direction})")
    else:
        batch_inference(args)
|
| 184 |
+
|
| 185 |
+
# Entry point guard: only run when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
scripts/batch_nus.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import random
import subprocess
import time

# Source scenes (each sub-folder holds an encoded_video-480p.pth) and output dir.
src_root = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_nus_right_2"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # change to your actual path

# Ensure the output directory exists before the first inference run.
os.makedirs(dst_root, exist_ok=True)

while True:
    # Pick a random scene sub-folder each iteration.
    subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
    if not subdirs:
        print("没有可用的子文件夹")
        break
    chosen = random.choice(subdirs)
    chosen_dir = os.path.join(src_root, chosen)
    pth_file = os.path.join(chosen_dir, "encoded_video-480p.pth")
    if not os.path.exists(pth_file):
        print(f"{pth_file} 不存在,跳过")
        continue

    # Output file named after the scene folder.
    out_file = os.path.join(dst_root, f"{chosen}.mp4")
    if os.path.exists(out_file):
        # Already generated in an earlier iteration — skip so the endless
        # loop makes progress instead of re-rendering the same scene.
        continue
    print(f"开始生成: {pth_file} -> {out_file}")

    # Build the inference command.
    cmd = [
        "python", infer_script,
        "--condition_pth", pth_file,
        "--output_path", out_file,
        "--prompt", "a car is driving",
        "--modality_type", "nuscenes",
        "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt"
    ]

    # Pin the child process to the second GPU.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = "1"

    # Run inference.
    subprocess.run(cmd, env=env)
|
scripts/batch_rt.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import random
import subprocess
import time

# Source episodes (each sub-folder holds an encoded_video.pth) and output dir.
src_root = "/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_RT"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # change to your actual path

# Ensure the output directory exists before the first inference run.
os.makedirs(dst_root, exist_ok=True)

while True:
    # Pick a random episode sub-folder each iteration.
    subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
    if not subdirs:
        print("没有可用的子文件夹")
        break
    chosen = random.choice(subdirs)
    chosen_dir = os.path.join(src_root, chosen)
    pth_file = os.path.join(chosen_dir, "encoded_video.pth")
    if not os.path.exists(pth_file):
        print(f"{pth_file} 不存在,跳过")
        continue

    # Output file named after the episode folder.
    out_file = os.path.join(dst_root, f"{chosen}.mp4")
    if os.path.exists(out_file):
        # Already generated in an earlier iteration — skip so the endless
        # loop makes progress instead of re-rendering the same episode.
        continue
    print(f"开始生成: {pth_file} -> {out_file}")

    # Build the inference command.
    cmd = [
        "python", infer_script,
        "--condition_pth", pth_file,
        "--output_path", out_file,
        "--prompt", "A robotic arm is moving the object",
        "--modality_type", "openx",
    ]

    # Pin the child process to the second GPU.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = "1"

    # Run inference.
    subprocess.run(cmd, env=env)
|
scripts/batch_spa.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import subprocess
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
src_root = "/share_zhuyixuan05/zhuyixuan05/spatialvid"
|
| 7 |
+
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_spatialvid_right"
|
| 8 |
+
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py" # 修改为你的实际路径
|
| 9 |
+
|
| 10 |
+
while True:
|
| 11 |
+
# 随机选择一个子文件夹
|
| 12 |
+
subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
|
| 13 |
+
if not subdirs:
|
| 14 |
+
print("没有可用的子文件夹")
|
| 15 |
+
break
|
| 16 |
+
chosen = random.choice(subdirs)
|
| 17 |
+
chosen_dir = os.path.join(src_root, chosen)
|
| 18 |
+
pth_file = os.path.join(chosen_dir, "encoded_video.pth")
|
| 19 |
+
if not os.path.exists(pth_file):
|
| 20 |
+
print(f"{pth_file} 不存在,跳过")
|
| 21 |
+
continue
|
| 22 |
+
|
| 23 |
+
# 生成输出文件名
|
| 24 |
+
out_file = os.path.join(dst_root, f"{chosen}.mp4")
|
| 25 |
+
print(f"开始生成: {pth_file} -> {out_file}")
|
| 26 |
+
|
| 27 |
+
# 构造命令
|
| 28 |
+
cmd = [
|
| 29 |
+
"python", infer_script,
|
| 30 |
+
"--condition_pth", pth_file,
|
| 31 |
+
"--output_path", out_file,
|
| 32 |
+
"--prompt", "exploring the world",
|
| 33 |
+
"--modality_type", "sekai",
|
| 34 |
+
#"--direction", "left",
|
| 35 |
+
"--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt"
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
# 仅使用第二张 GPU
|
| 39 |
+
env = os.environ.copy()
|
| 40 |
+
env["CUDA_VISIBLE_DEVICES"] = "0"
|
| 41 |
+
|
| 42 |
+
# 执行推理
|
| 43 |
+
subprocess.run(cmd, env=env)
|
scripts/batch_walk.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import random
import subprocess
import time

# Source scenes (each sub-folder holds an encoded_video-480p.pth) and output dir.
src_root = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_walk"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # change to your actual path

# Ensure the output directory exists before the first inference run.
os.makedirs(dst_root, exist_ok=True)

while True:
    # Pick a random scene sub-folder each iteration.
    subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
    if not subdirs:
        print("没有可用的子文件夹")
        break
    chosen = random.choice(subdirs)
    chosen_dir = os.path.join(src_root, chosen)
    pth_file = os.path.join(chosen_dir, "encoded_video-480p.pth")
    if not os.path.exists(pth_file):
        print(f"{pth_file} 不存在,跳过")
        continue

    # Output file named after the scene folder.
    out_file = os.path.join(dst_root, f"{chosen}.mp4")
    if os.path.exists(out_file):
        # Already generated in an earlier iteration — skip so the endless
        # loop makes progress instead of re-rendering the same scene.
        continue
    print(f"开始生成: {pth_file} -> {out_file}")

    # Build the inference command.
    cmd = [
        "python", infer_script,
        "--condition_pth", pth_file,
        "--output_path", out_file,
        "--prompt", "a car is driving",
        "--modality_type", "nuscenes",
        "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt"
    ]

    # Pin the child process to the second GPU.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = "1"

    # Run inference.
    subprocess.run(cmd, env=env)
|
scripts/check.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
def load_checkpoint(ckpt_path):
    """Load a checkpoint from *ckpt_path*; return None if missing or unreadable."""
    if not os.path.exists(ckpt_path):
        return None

    try:
        return torch.load(ckpt_path, map_location='cpu')
    except Exception as e:
        # Keep the comparison tool usable even with a corrupt file: report and skip.
        print(f"❌ 加载检查点失败: {e}")
        return None
|
| 18 |
+
|
| 19 |
+
def compare_parameters(state_dict1, state_dict2, threshold=1e-8):
    """Compare two state dicts and split their shared parameters into updated vs. unchanged.

    Args:
        state_dict1: baseline state dict (name -> tensor).
        state_dict2: newer state dict to compare against.
        threshold: a parameter counts as "updated" when the max absolute
            element-wise difference exceeds this value.

    Returns:
        (updated_params, unchanged_params): two dicts mapping parameter name to
        {'max_diff', 'mean_diff', 'shape'}. Returns (None, None) when either
        input is None. Parameters present in only one dict are ignored.
    """
    if state_dict1 is None or state_dict2 is None:
        # Return a 2-tuple here: every caller unpacks the result
        # (`a, b = compare_parameters(...)`), so a bare None would raise
        # TypeError before the caller's own None-check could run.
        return None, None

    updated_params = {}
    unchanged_params = {}

    for name, param1 in state_dict1.items():
        if name not in state_dict2:
            continue
        param2 = state_dict2[name]

        if param1.shape != param2.shape:
            # A shape change is always an update; element-wise subtraction
            # would raise on mismatched shapes.
            updated_params[name] = {
                'max_diff': float('inf'),
                'mean_diff': float('inf'),
                'shape': param1.shape,
            }
            continue

        # Compute the difference in float32 so integer tensors (e.g. step
        # counters stored in checkpoints) do not make torch.mean raise.
        diff = torch.abs(param1.float() - param2.float())
        max_diff = torch.max(diff).item()
        mean_diff = torch.mean(diff).item()

        info = {
            'max_diff': max_diff,
            'mean_diff': mean_diff,
            'shape': param1.shape,
        }
        if max_diff > threshold:
            updated_params[name] = info
        else:
            unchanged_params[name] = info

    return updated_params, unchanged_params
|
| 50 |
+
|
| 51 |
+
def categorize_parameters(param_dict):
    """Bucket parameter-diff entries by component, based on name substrings.

    The first matching category wins; names matching nothing fall into 'other'.
    Returns a dict with keys: moe_related, camera_related, framepack_related,
    attention, other.
    """
    # Ordered keyword table mirrors the priority of the original if/elif chain.
    keyword_table = (
        ('moe_related', ('moe', 'gate', 'expert', 'processor')),
        ('camera_related', ('cam_encoder', 'projector', 'camera')),
        ('framepack_related', ('clean_x_embedder', 'framepack')),
        ('attention', ('attn', 'attention')),
    )

    categories = {label: {} for label, _ in keyword_table}
    categories['other'] = {}

    for name, info in param_dict.items():
        lowered = name.lower()
        for label, keywords in keyword_table:
            if any(k in lowered for k in keywords):
                categories[label][name] = info
                break
        else:
            categories['other'][name] = info

    return categories
|
| 74 |
+
|
| 75 |
+
def print_category_summary(category_name, params, color_code='', top_k=100):
    """Print a console summary for one category of parameter diffs.

    Args:
        category_name: human-readable category label.
        params: dict mapping parameter name to {'max_diff', 'mean_diff', 'shape'}.
        color_code: emoji/prefix prepended to the heading lines.
        top_k: how many of the largest-changing parameters to list. Defaults
            to 100, matching the previously hard-coded slice (the old comment
            claimed "top 5" but the code listed 100; the count is now explicit).
    """
    if not params:
        print(f"{color_code} {category_name}: 无参数")
        return

    total_params = len(params)
    max_diffs = [info['max_diff'] for info in params.values()]
    mean_diffs = [info['mean_diff'] for info in params.values()]

    print(f"{color_code} {category_name} ({total_params} 个参数):")
    print(f" 最大差异范围: {min(max_diffs):.2e} ~ {max(max_diffs):.2e}")
    print(f" 平均差异范围: {min(mean_diffs):.2e} ~ {max(mean_diffs):.2e}")

    # List the top_k parameters with the largest max_diff, biggest first.
    sorted_params = sorted(params.items(), key=lambda x: x[1]['max_diff'], reverse=True)
    print(f" 变化最大的参数:")
    for i, (name, info) in enumerate(sorted_params[:top_k]):
        shape_str = 'x'.join(map(str, info['shape']))
        print(f" {i+1}. {name} [{shape_str}]: max_diff={info['max_diff']:.2e}")
|
| 95 |
+
|
| 96 |
+
def monitor_training(checkpoint_dir, check_interval=60):
    """Poll a checkpoint directory forever and report parameter updates.

    Every `check_interval` seconds the newest `step*.ckpt` file is loaded and
    diffed against the previously seen checkpoint via compare_parameters();
    the per-category results are printed. Runs until KeyboardInterrupt.

    NOTE(review): the step number is parsed with
    int(name.replace('step', '').replace('.ckpt', '')), which assumes file
    names of the exact form `step<N>.ckpt`. Names like
    `step500_origin_cam_4.ckpt` (used as defaults in __main__) raise
    ValueError here, which is swallowed by the broad except below — confirm
    the expected naming scheme.
    """
    print(f"🔍 开始监控训练进度...")
    print(f"📁 检查点目录: {checkpoint_dir}")
    print(f"⏰ 检查间隔: {check_interval}秒")
    print("=" * 80)

    # State carried across polling iterations: last loaded state dict and
    # its step number (-1 means "nothing seen yet").
    previous_ckpt = None
    previous_step = -1

    while True:
        try:
            # Look for the newest checkpoint; keep waiting if the directory
            # does not exist yet (e.g. training has not started).
            if not os.path.exists(checkpoint_dir):
                print(f"❌ 检查点目录不存在: {checkpoint_dir}")
                time.sleep(check_interval)
                continue

            ckpt_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('step') and f.endswith('.ckpt')]
            if not ckpt_files:
                print("⏳ 未找到检查点文件,等待中...")
                time.sleep(check_interval)
                continue

            # Sort by step number and take the most recent checkpoint.
            ckpt_files.sort(key=lambda x: int(x.replace('step', '').replace('.ckpt', '')))
            latest_ckpt_file = ckpt_files[-1]
            latest_ckpt_path = os.path.join(checkpoint_dir, latest_ckpt_file)

            # Extract the step number from the file name.
            current_step = int(latest_ckpt_file.replace('step', '').replace('.ckpt', ''))

            # Nothing newer than what we already compared — wait for training
            # to produce the next checkpoint.
            if current_step <= previous_step:
                print(f"⏳ 等待新的检查点... (当前: step{current_step})")
                time.sleep(check_interval)
                continue

            print(f"\n🔍 发现新检查点: {latest_ckpt_file}")

            # Load the current checkpoint (returns None on failure).
            current_state_dict = load_checkpoint(latest_ckpt_path)
            if current_state_dict is None:
                print("❌ 无法加载当前检查点")
                time.sleep(check_interval)
                continue

            # First iteration has no baseline; diffing starts from the
            # second checkpoint we see.
            if previous_ckpt is not None:
                print(f"📊 比较 step{previous_step} -> step{current_step}")

                # Diff the two state dicts element-wise.
                updated_params, unchanged_params = compare_parameters(
                    previous_ckpt, current_state_dict, threshold=1e-8
                )

                if updated_params is None:
                    print("❌ 参数比较失败")
                else:
                    # Group results by component and print both buckets.
                    updated_categories = categorize_parameters(updated_params)
                    unchanged_categories = categorize_parameters(unchanged_params)

                    print(f"\n✅ 已更新的参数 (总共 {len(updated_params)} 个):")
                    print_category_summary("MoE相关", updated_categories['moe_related'], '🔥')
                    print_category_summary("Camera相关", updated_categories['camera_related'], '📷')
                    print_category_summary("FramePack相关", updated_categories['framepack_related'], '🎞️')
                    print_category_summary("注意力相关", updated_categories['attention'], '👁️')
                    print_category_summary("其他", updated_categories['other'], '📦')

                    print(f"\n⚠️ 未更新的参数 (总共 {len(unchanged_params)} 个):")
                    print_category_summary("MoE相关", unchanged_categories['moe_related'], '❄️')
                    print_category_summary("Camera相关", unchanged_categories['camera_related'], '❄️')
                    print_category_summary("FramePack相关", unchanged_categories['framepack_related'], '❄️')
                    print_category_summary("注意力相关", unchanged_categories['attention'], '❄️')
                    print_category_summary("其他", unchanged_categories['other'], '❄️')

                    # Warn loudly if none of the components under training
                    # (MoE / camera / framepack) changed between checkpoints.
                    critical_keywords = ['moe', 'cam_encoder', 'projector', 'clean_x_embedder']
                    critical_updated = any(
                        any(keyword in name.lower() for keyword in critical_keywords)
                        for name in updated_params.keys()
                    )

                    if critical_updated:
                        print("\n✅ 关键组件正在更新!")
                    else:
                        print("\n❌ 警告:关键组件可能未在更新!")

                    # Fraction of shared parameters that changed.
                    total_params = len(updated_params) + len(unchanged_params)
                    update_rate = len(updated_params) / total_params * 100
                    print(f"\n📈 参数更新率: {update_rate:.1f}% ({len(updated_params)}/{total_params})")

            # Remember this checkpoint as the baseline for the next diff.
            previous_ckpt = current_state_dict
            previous_step = current_step

            print("=" * 80)
            time.sleep(check_interval)

        except KeyboardInterrupt:
            # Ctrl-C is the intended way to stop the monitor.
            print("\n👋 监控已停止")
            break
        except Exception as e:
            # Any other error (bad file name, truncated checkpoint, ...) is
            # reported and the loop keeps polling.
            print(f"❌ 监控过程中出错: {e}")
            time.sleep(check_interval)
|
| 201 |
+
|
| 202 |
+
def compare_two_checkpoints(ckpt1_path, ckpt2_path):
    """Load two specific checkpoint files and print a per-category diff report."""
    print(f"🔍 比较两个检查点:")
    print(f" 检查点1: {ckpt1_path}")
    print(f" 检查点2: {ckpt2_path}")
    print("=" * 80)

    # Load both checkpoints; bail out if either fails to load.
    loaded = [load_checkpoint(path) for path in (ckpt1_path, ckpt2_path)]
    if any(sd is None for sd in loaded):
        print("❌ 无法加载检查点文件")
        return
    state_dict1, state_dict2 = loaded

    # Diff the two state dicts with the default threshold.
    updated_params, unchanged_params = compare_parameters(state_dict1, state_dict2)
    if updated_params is None:
        print("❌ 参数比较失败")
        return

    # One section per bucket: (bucket dict, section heading, emoji prefix).
    sections = (
        (updated_params, f"\n✅ 已更新的参数 (总共 {len(updated_params)} 个):", '🔥'),
        (unchanged_params, f"\n⚠️ 未更新的参数 (总共 {len(unchanged_params)} 个):", '❄️'),
    )
    for bucket, heading, emoji in sections:
        print(heading)
        for category_name, params in categorize_parameters(bucket).items():
            print_category_summary(category_name.replace('_', ' ').title(), params, emoji)

    # Overall fraction of shared parameters that changed.
    total_params = len(updated_params) + len(unchanged_params)
    update_rate = len(updated_params) / total_params * 100
    print(f"\n📈 参数更新率: {update_rate:.1f}% ({len(updated_params)}/{total_params})")
|
| 240 |
+
|
| 241 |
+
if __name__ == '__main__':
    def _parse_bool(value):
        """Parse a CLI boolean so that `--compare False` really disables compare mode.

        The previous `default=True` with no type meant any supplied string
        (including "False") was truthy.
        """
        if isinstance(value, bool):
            return value
        return value.strip().lower() in ('1', 'true', 't', 'yes', 'y')

    parser = argparse.ArgumentParser(description="检查模型参数更新情况")
    parser.add_argument("--checkpoint_dir", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe",
                        help="检查点目录路径")
    parser.add_argument("--compare", type=_parse_bool, default=True,
                        help="比较两个特定检查点,而不是监控")
    parser.add_argument("--ckpt1", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step1500_origin_cam_4.ckpt")
    parser.add_argument("--ckpt2", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step500_origin_cam_4.ckpt")
    parser.add_argument("--interval", type=int, default=60,
                        help="监控检查间隔(秒)")
    # NOTE(review): --threshold is accepted but never forwarded; both
    # compare_parameters() call sites use their own default of 1e-8.
    parser.add_argument("--threshold", type=float, default=1e-8,
                        help="参数变化阈值")

    args = parser.parse_args()

    if args.compare:
        # Compare mode: diff two explicit checkpoint files once and exit.
        if not args.ckpt1 or not args.ckpt2:
            print("❌ 比较模式需要指定 --ckpt1 和 --ckpt2")
        else:
            compare_two_checkpoints(args.ckpt1, args.ckpt2)
    else:
        # Monitor mode: poll the checkpoint directory forever.
        monitor_training(args.checkpoint_dir, args.interval)
|
scripts/decode_openx.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import imageio
|
| 6 |
+
import argparse
|
| 7 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
class VideoDecoder:
    """Wraps a WanVideoReCamMaster pipeline to decode saved latents back to video.

    Only the VAE part of the pipeline is actually used; the class exists to
    verify that previously-encoded episode latents round-trip to plausible
    RGB frames.
    """

    def __init__(self, vae_path, device="cuda"):
        """Initialize the video decoder.

        Args:
            vae_path: path to the VAE weights file loaded via ModelManager.
            device: target device for the pipeline and VAE (e.g. "cuda", "cpu").
        """
        self.device = device

        # Build the model manager (weights are first loaded on CPU in bf16).
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([vae_path])

        # Create the pipeline; only its VAE is exercised by this class.
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe = self.pipe.to(device)

        # Key fix: make sure the VAE and all of its sub-modules really end up
        # on the requested device (pipe.to() alone was not sufficient here).
        self.pipe.vae = self.pipe.vae.to(device)
        if hasattr(self.pipe.vae, 'model'):
            self.pipe.vae.model = self.pipe.vae.model.to(device)

        print(f"✅ VAE解码器初始化完成,设备: {device}")

    def decode_latents_to_video(self, latents, output_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode latents to an mp4 video file and return the frames as uint8 numpy.

        Args:
            latents: latent tensor, [C, T, H, W] or [B, C, T, H, W]
                (a missing batch dim is added; assumes channel-first layout —
                TODO confirm against the encoder).
            output_path: destination .mp4 path (parent dirs are created).
            tiled: use tiled VAE decoding (with tile_size/tile_stride) when True.
            tile_size, tile_stride: tiling parameters forwarded to decode_video.

        Returns:
            numpy uint8 array of shape [T, H, W, C] in range 0..255.

        Raises:
            ValueError when the decoded tensor's layout cannot be resolved
            to [T, H, W, C]; re-raises decoder/writer failures.
        """
        print(f"🔧 开始解码latents...")
        print(f"输入latents形状: {latents.shape}")
        print(f"输入latents设备: {latents.device}")
        print(f"输入latents数据类型: {latents.dtype}")

        # Ensure a batch dimension is present.
        if len(latents.shape) == 4:  # [C, T, H, W]
            latents = latents.unsqueeze(0)  # -> [1, C, T, H, W]

        # Key fix: read the device/dtype actually used by the VAE parameters
        # and convert the latents to match, instead of trusting self.device.
        model_dtype = next(self.pipe.vae.parameters()).dtype
        model_device = next(self.pipe.vae.parameters()).device

        print(f"模型设备: {model_device}")
        print(f"模型数据类型: {model_dtype}")

        # Move the latents to the model's device and dtype.
        latents = latents.to(device=model_device, dtype=model_dtype)

        print(f"解码latents形状: {latents.shape}")
        print(f"解码latents设备: {latents.device}")
        print(f"解码latents数据类型: {latents.dtype}")

        # Force the pipeline's device attribute so every internal op runs on
        # the same device as the VAE weights.
        self.pipe.device = model_device

        # Decode with the VAE.
        with torch.no_grad():
            try:
                if tiled:
                    print("🔧 尝试tiled解码...")
                    decoded_video = self.pipe.decode_video(
                        latents,
                        tiled=True,
                        tile_size=tile_size,
                        tile_stride=tile_stride
                    )
                else:
                    print("🔧 使用非tiled解码...")
                    decoded_video = self.pipe.decode_video(latents, tiled=False)

            except Exception as e:
                print(f"decode_video失败,错误: {e}")
                import traceback
                traceback.print_exc()

                # Fallback: call the VAE's decode directly (no pipeline wrapper).
                try:
                    print("🔧 尝试直接调用VAE解码...")
                    decoded_video = self.pipe.vae.decode(
                        latents.squeeze(0),  # drop the batch dim -> [C, T, H, W]
                        device=model_device,
                        tiled=False
                    )
                    # Manually restore a batch dim: presumed VAE output
                    # [T, H, W, C] -> [1, T, H, W, C] — TODO confirm the
                    # VAE's actual output layout.
                    if len(decoded_video.shape) == 4:  # [T, H, W, C]
                        decoded_video = decoded_video.unsqueeze(0)  # -> [1, T, H, W, C]
                except Exception as e2:
                    print(f"直接VAE解码也失败: {e2}")
                    raise e2

        print(f"解码后视频形状: {decoded_video.shape}")

        # Key fix: the decoder's output layout is not guaranteed, so probe
        # several plausible dimension orders and normalize to [T, H, W, C].
        video_np = None

        if len(decoded_video.shape) == 5:
            # Check the possible dimension orders in turn.
            if decoded_video.shape == torch.Size([1, 3, 113, 480, 832]):
                # Known concrete case: [B, C, T, H, W] -> convert to [T, H, W, C].
                print("🔧 检测到格式: [B, C, T, H, W]")
                video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()  # [T, H, W, C]
            elif decoded_video.shape[1] == 3:
                # Second dim of size 3 suggests channels-first [B, C, T, H, W].
                print("🔧 检测到可能的格式: [B, C, T, H, W]")
                video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()  # [T, H, W, C]
            elif decoded_video.shape[-1] == 3:
                # Last dim of size 3 suggests channels-last [B, T, H, W, C].
                print("🔧 检测到格式: [B, T, H, W, C]")
                video_np = decoded_video[0].to(torch.float32).cpu().numpy()  # [T, H, W, C]
            else:
                # Last resort: find whichever axis has size 3 and treat it as
                # the channel axis (can misfire if T/H/W happens to be 3).
                shape = list(decoded_video.shape)
                if 3 in shape:
                    channel_dim = shape.index(3)
                    print(f"🔧 检测到通道维度在位置: {channel_dim}")

                    if channel_dim == 1:  # [B, C, T, H, W]
                        video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()
                    elif channel_dim == 4:  # [B, T, H, W, C]
                        video_np = decoded_video[0].to(torch.float32).cpu().numpy()
                    else:
                        print(f"⚠️ 未知的通道维度位置: {channel_dim}")
                        raise ValueError(f"Cannot handle channel dimension at position {channel_dim}")
                else:
                    print(f"⚠️ 未找到通道维度为3的位置,形状: {decoded_video.shape}")
                    raise ValueError(f"Cannot find channel dimension of size 3 in shape {decoded_video.shape}")

        elif len(decoded_video.shape) == 4:
            # 4-D tensor: no batch dim; decide between channels-last/-first.
            if decoded_video.shape[-1] == 3:  # [T, H, W, C]
                video_np = decoded_video.to(torch.float32).cpu().numpy()
            elif decoded_video.shape[0] == 3:  # [C, T, H, W]
                video_np = decoded_video.permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()
            else:
                print(f"⚠️ 无法处理的4D视频形状: {decoded_video.shape}")
                raise ValueError(f"Cannot handle 4D video tensor shape: {decoded_video.shape}")
        else:
            print(f"⚠️ 意外的视频维度数: {len(decoded_video.shape)}")
            raise ValueError(f"Unexpected video tensor dimensions: {decoded_video.shape}")

        if video_np is None:
            raise ValueError("Failed to convert video tensor to numpy array")

        print(f"转换后视频数组形状: {video_np.shape}")

        # Validate the final shape is [T, H, W, C].
        if len(video_np.shape) != 4:
            raise ValueError(f"Expected 4D array [T, H, W, C], got {video_np.shape}")

        if video_np.shape[-1] != 3:
            print(f"⚠️ 通道数异常: 期望3,实际{video_np.shape[-1]}")
            print(f"完整形状: {video_np.shape}")
            # Try the remaining plausible permutations before giving up.
            if video_np.shape[0] == 3:  # [C, T, H, W]
                print("🔧 尝试重新排列: [C, T, H, W] -> [T, H, W, C]")
                video_np = np.transpose(video_np, (1, 2, 3, 0))
            elif video_np.shape[1] == 3:  # [T, C, H, W]
                print("🔧 尝试重新排列: [T, C, H, W] -> [T, H, W, C]")
                video_np = np.transpose(video_np, (0, 2, 3, 1))
            else:
                raise ValueError(f"Expected 3 channels (RGB), got {video_np.shape[-1]} channels")

        # De-normalize: assumes decoder output in [-1, 1] — TODO confirm
        # against the VAE's output range.
        video_np = (video_np * 0.5 + 0.5).clip(0, 1)  # de-normalize to [0, 1]
        video_np = (video_np * 255).astype(np.uint8)

        print(f"最终视频数组形状: {video_np.shape}")
        print(f"视频数组值范围: {video_np.min()} - {video_np.max()}")

        # Write the video to disk.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        try:
            with imageio.get_writer(output_path, fps=10, quality=8) as writer:
                for frame_idx, frame in enumerate(video_np):
                    # Skip (rather than crash on) frames with an unexpected shape.
                    if len(frame.shape) != 3 or frame.shape[-1] != 3:
                        print(f"⚠️ 帧 {frame_idx} 形状异常: {frame.shape}")
                        continue

                    writer.append_data(frame)
                    if frame_idx % 10 == 0:
                        print(f" 写入帧 {frame_idx}/{len(video_np)}")
        except Exception as e:
            print(f"保存视频失败: {e}")
            # On writer failure, dump the first few frames as PNGs for debugging,
            # then re-raise the original error.
            debug_dir = os.path.join(os.path.dirname(output_path), "debug_frames")
            os.makedirs(debug_dir, exist_ok=True)

            for i in range(min(5, len(video_np))):
                frame = video_np[i]
                debug_path = os.path.join(debug_dir, f"debug_frame_{i}.png")
                try:
                    if len(frame.shape) == 3 and frame.shape[-1] == 3:
                        Image.fromarray(frame).save(debug_path)
                        print(f"调试: 保存帧 {i} 到 {debug_path}")
                    else:
                        print(f"调试: 帧 {i} 形状异常: {frame.shape}")
                except Exception as e2:
                    print(f"调试: 保存帧 {i} 失败: {e2}")
            raise e

        print(f"✅ 视频保存到: {output_path}")
        return video_np

    def save_frames_as_images(self, video_np, output_dir, prefix="frame"):
        """Save each video frame as an individual PNG named {prefix}_{i:04d}.png.

        Frames whose shape is not [H, W, 3] are skipped with a warning.
        """
        os.makedirs(output_dir, exist_ok=True)

        for i, frame in enumerate(video_np):
            frame_path = os.path.join(output_dir, f"{prefix}_{i:04d}.png")
            # Only write well-formed RGB frames.
            if len(frame.shape) == 3 and frame.shape[-1] == 3:
                Image.fromarray(frame).save(frame_path)
            else:
                print(f"⚠️ 跳过形状异常的帧 {i}: {frame.shape}")

        print(f"✅ 保存了 {len(video_np)} 帧到: {output_dir}")
|
| 224 |
+
|
| 225 |
+
def decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device="cuda"):
    """Decode one episode's encoded latents to video and write verification artifacts.

    Loads `encoded_video.pth`, decodes its 'latents' entry with a freshly
    constructed VideoDecoder, saves the mp4, a few sample frames, and a
    decode_info.json under output_base_dir/episode_<idx>/.

    Returns:
        True on success, False on any load/decode failure (errors are printed).
    """
    print(f"\n🔧 解码episode: {encoded_pth_path}")

    # Load the encoded episode data onto the CPU.
    try:
        encoded_data = torch.load(encoded_pth_path, weights_only=False, map_location="cpu")
        print(f"✅ 成功加载编码数据")
    except Exception as e:
        print(f"❌ 加载编码数据失败: {e}")
        return False

    # Dump the structure of the loaded dict for inspection.
    print("🔍 编码数据结构:")
    for key, value in encoded_data.items():
        if isinstance(value, torch.Tensor):
            print(f" - {key}: {value.shape}, dtype: {value.dtype}, device: {value.device}")
        elif isinstance(value, dict):
            print(f" - {key}: dict with keys {list(value.keys())}")
        else:
            print(f" - {key}: {type(value)}")

    # Pull out the latents tensor; nothing to decode without it.
    latents = encoded_data.get('latents')
    if latents is None:
        print("❌ 未找到latents数据")
        return False

    # Keep the latents on CPU (the decoder moves them to the VAE's device later).
    if latents.device != torch.device('cpu'):
        latents = latents.cpu()
        print(f"🔧 将latents移动到CPU: {latents.device}")

    episode_info = encoded_data.get('episode_info', {})
    episode_idx = episode_info.get('episode_idx', 'unknown')
    # Fallback estimate of the original frame count when metadata is missing;
    # assumes latents.shape[1] is the temporal axis and a 4x temporal
    # compression — TODO confirm against the encoder script.
    total_frames = episode_info.get('total_frames', latents.shape[1] * 4)

    print(f"Episode信息:")
    print(f" - Episode索引: {episode_idx}")
    print(f" - Latents形状: {latents.shape}")
    print(f" - Latents设备: {latents.device}")
    print(f" - Latents数据类型: {latents.dtype}")
    print(f" - 原始总帧数: {total_frames}")
    print(f" - 压缩后帧数: {latents.shape[1]}")

    # Create the per-episode output directory (zero-padded name when the
    # index is an int, verbatim otherwise).
    episode_name = f"episode_{episode_idx:06d}" if isinstance(episode_idx, int) else f"episode_{episode_idx}"
    output_dir = os.path.join(output_base_dir, episode_name)
    os.makedirs(output_dir, exist_ok=True)

    # Build the decoder (loads the VAE; can fail on bad paths / OOM).
    try:
        decoder = VideoDecoder(vae_path, device)
    except Exception as e:
        print(f"❌ 初始化解码器失败: {e}")
        return False

    # Decode to an mp4 file.
    video_output_path = os.path.join(output_dir, "decoded_video.mp4")
    try:
        video_np = decoder.decode_latents_to_video(
            latents,
            video_output_path,
            tiled=False,  # try non-tiled decoding first to sidestep tiling complexity
            tile_size=(34, 34),
            tile_stride=(18, 16)
        )

        # Save the first few frames as PNGs for a quick visual check.
        frames_dir = os.path.join(output_dir, "frames")
        sample_frames = video_np[:min(10, len(video_np))]  # at most 10 frames
        decoder.save_frames_as_images(sample_frames, frames_dir, f"frame_{episode_idx}")

        # Record decode metadata alongside the outputs.
        decode_info = {
            "source_pth": encoded_pth_path,
            "decoded_video_path": video_output_path,
            "latents_shape": list(latents.shape),
            "decoded_video_shape": list(video_np.shape),
            "original_total_frames": total_frames,
            "decoded_frames": len(video_np),
            "compression_ratio": total_frames / len(video_np) if len(video_np) > 0 else 0,
            "latents_dtype": str(latents.dtype),
            "latents_device": str(latents.device),
            "vae_compression_ratio": total_frames / latents.shape[1] if latents.shape[1] > 0 else 0
        }

        info_path = os.path.join(output_dir, "decode_info.json")
        with open(info_path, 'w') as f:
            json.dump(decode_info, f, indent=2)

        print(f"✅ Episode {episode_idx} 解码完成")
        print(f" - 原始帧数: {total_frames}")
        print(f" - 解码帧数: {len(video_np)}")
        print(f" - 压缩比: {decode_info['compression_ratio']:.2f}")
        print(f" - VAE时间压缩比: {decode_info['vae_compression_ratio']:.2f}")
        return True

    except Exception as e:
        print(f"❌ 解码失败: {e}")
        import traceback
        traceback.print_exc()
        return False
|
| 328 |
+
|
| 329 |
+
def batch_decode_episodes(encoded_base_dir, vae_path, output_base_dir, max_episodes=None, device="cuda"):
    """Decode every encoded episode found under encoded_base_dir.

    An episode is any subdirectory containing an `encoded_video.pth` file.

    Args:
        encoded_base_dir: directory holding per-episode subdirectories.
        vae_path: VAE weights path forwarded to decode_single_episode.
        output_base_dir: where decoded outputs are written.
        max_episodes: optional cap on the number of episodes to process.
        device: compute device forwarded to decode_single_episode.
    """
    print(f"🔧 批量解码Open-X episodes")
    print(f"源目录: {encoded_base_dir}")
    print(f"输出目录: {output_base_dir}")

    # Collect the encoded episode files (sorted for deterministic order).
    episode_dirs = []
    if os.path.exists(encoded_base_dir):
        for item in sorted(os.listdir(encoded_base_dir)):
            episode_dir = os.path.join(encoded_base_dir, item)
            if os.path.isdir(episode_dir):
                encoded_path = os.path.join(episode_dir, "encoded_video.pth")
                if os.path.exists(encoded_path):
                    episode_dirs.append(encoded_path)

    print(f"找到 {len(episode_dirs)} 个编码的episodes")

    if not episode_dirs:
        # Bug fix: with zero episodes the final success-rate print below
        # divided by len(episode_dirs) and raised ZeroDivisionError.
        print("❌ 未找到可解码的episodes")
        return

    if max_episodes and len(episode_dirs) > max_episodes:
        episode_dirs = episode_dirs[:max_episodes]
        print(f"限制处理前 {max_episodes} 个episodes")

    # Decode each episode, tracking the running success rate.
    success_count = 0
    for i, encoded_pth_path in enumerate(tqdm(episode_dirs, desc="解码episodes")):
        print(f"\n{'='*60}")
        print(f"处理 {i+1}/{len(episode_dirs)}: {os.path.basename(os.path.dirname(encoded_pth_path))}")

        success = decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device)
        if success:
            success_count += 1

        print(f"当前成功率: {success_count}/{i+1} ({success_count/(i+1)*100:.1f}%)")

    print(f"\n🎉 批量解码完成!")
    print(f"总处理: {len(episode_dirs)} 个episodes")
    print(f"成功解码: {success_count} 个episodes")
    print(f"成功率: {success_count/len(episode_dirs)*100:.1f}%")
|
| 367 |
+
|
| 368 |
+
def _build_arg_parser():
    """Construct the CLI parser for the latent-decoding verification tool."""
    parser = argparse.ArgumentParser(description="解码Open-X编码的latents以验证正确性 - 修正版本")
    parser.add_argument("--mode", type=str, choices=["single", "batch"], default="batch",
                        help="解码模式:single (单个episode) 或 batch (批量)")
    parser.add_argument("--encoded_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000000/encoded_video.pth",
                        help="单个编码文件路径(single模式)")
    parser.add_argument("--encoded_base_dir", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
                        help="编码数据基础目录(batch模式)")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
                        help="VAE模型路径")
    parser.add_argument("--output_dir", type=str,
                        default="./decoded_results_fixed",
                        help="解码输出目录")
    parser.add_argument("--max_episodes", type=int, default=5,
                        help="最大解码episodes数量(batch模式,用于测试)")
    parser.add_argument("--device", type=str, default="cuda",
                        help="计算设备")
    return parser


def main():
    """Entry point: parse CLI options and dispatch to single or batch decoding."""
    args = _build_arg_parser().parse_args()

    print("🔧 Open-X Latents 解码验证工具 (修正版本 - Fixed)")
    print(f"模式: {args.mode}")
    print(f"VAE路径: {args.vae_path}")
    print(f"输出目录: {args.output_dir}")
    print(f"设备: {args.device}")

    # Fall back to CPU when CUDA was requested but is unavailable.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("⚠️ CUDA不可用,切换到CPU")
        args.device = "cpu"

    # Make sure the output directory exists before any decoding starts.
    os.makedirs(args.output_dir, exist_ok=True)

    if args.mode == "single":
        print(f"输入文件: {args.encoded_pth}")
        if not os.path.exists(args.encoded_pth):
            print(f"❌ 输入文件不存在: {args.encoded_pth}")
            return

        decoded_ok = decode_single_episode(args.encoded_pth, args.vae_path, args.output_dir, args.device)
        print("✅ 单个episode解码成功" if decoded_ok else "❌ 单个episode解码失败")

    elif args.mode == "batch":
        print(f"输入目录: {args.encoded_base_dir}")
        print(f"最大episodes: {args.max_episodes}")

        if not os.path.exists(args.encoded_base_dir):
            print(f"❌ 输入目录不存在: {args.encoded_base_dir}")
            return

        batch_decode_episodes(args.encoded_base_dir, args.vae_path, args.output_dir, args.max_episodes, args.device)
|
| 426 |
+
|
| 427 |
+
if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly.
    main()
|
scripts/download_recam.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import snapshot_download

# Download the KwaiVGI/ReCamMaster-Wan2.1 checkpoints from the Hugging Face
# Hub into the local models directory.
snapshot_download(
    repo_id="KwaiVGI/ReCamMaster-Wan2.1",
    local_dir="models/ReCamMaster/checkpoints",
    resume_download=True  # resume interrupted downloads; NOTE(review): deprecated no-op in recent huggingface_hub versions — confirm the pinned version
)
|
scripts/download_wan2.1.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from modelscope import snapshot_download


# Download models
# Fetch Wan2.1-T2V-1.3B from ModelScope into the local models directory
# (the path other scripts in this repo use as their default vae/model root).
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
|
scripts/encode_dynamic_videos.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
class VideoEncoder(pl.LightningModule):
    """Wraps the ReCamMaster pipeline's text encoder and VAE for offline video encoding.

    Loads both models on CPU in bfloat16 and exposes helpers to read a video
    file into a normalized ``(C, T, H, W)`` tensor ready for
    ``pipe.encode_video``.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiling kwargs forwarded to pipe.encode_video to bound VAE memory use.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Map PIL frames to float tensors normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=832, target_height=480):
        """Resize a PIL image to the target resolution (default 832x480, bilinear)."""
        return v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Load a whole video as a normalized tensor of shape (C, T, H, W).

        Returns:
            torch.Tensor of shape (C, T, H, W), or ``None`` when the video
            contains no decodable frames.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frames.append(self.frame_process(frame))
        finally:
            # Release the reader even if a frame fails to decode.
            reader.close()

        if not frames:
            return None

        stacked = torch.stack(frames, dim=0)  # (T, C, H, W)
        return rearrange(stacked, "T C H W -> C T H W")
| 56 |
+
|
| 57 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path, start_idx=450):
    """Encode every scene video under ``scenes_path`` with the Wan2.1 VAE.

    For each scene directory this reads ``scene_info.json`` to locate the
    video, encodes it to latents on CUDA, and saves
    ``encoded_video-480p-1.pth`` containing the latents plus a shared
    text-prompt embedding.

    Args:
        scenes_path: Directory whose sub-directories are individual scenes.
        text_encoder_path: Path to the T5 text-encoder checkpoint.
        vae_path: Path to the Wan2.1 VAE checkpoint.
        start_idx: Scenes before this index are skipped (resume support;
            default 450 preserves the previously hard-coded behavior).
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    prompt_emb = None  # encoded once on first success, then reused

    for idx, scene_name in enumerate(tqdm(os.listdir(scenes_path))):
        if idx < start_idx:
            continue
        scene_dir = os.path.join(scenes_path, scene_name)
        if not os.path.isdir(scene_dir):
            continue

        # Skip scenes that were already encoded.
        encoded_path = os.path.join(scene_dir, "encoded_video-480p-1.pth")
        if os.path.exists(encoded_path):
            print(f"Scene {scene_name} already encoded, skipping...")
            continue

        # Load scene metadata.
        scene_info_path = os.path.join(scene_dir, "scene_info.json")
        if not os.path.exists(scene_info_path):
            continue

        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)

        # Load the video referenced by the metadata.
        video_path = os.path.join(scene_dir, scene_info['video_path'])
        if not os.path.exists(video_path):
            print(f"Video not found: {video_path}")
            continue

        try:
            print(f"Encoding scene {scene_name}...")

            video_frames = encoder.load_video_frames(video_path)
            if video_frames is None:
                print(f"Failed to load video: {video_path}")
                continue

            video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)

            # Encode the video to latents.
            with torch.no_grad():
                latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]

            # Encode the shared text prompt exactly once, then free the
            # prompter.  Guarding on ``prompt_emb`` (not processed_count)
            # avoids re-entering this branch after the prompter has been
            # deleted when the first scene fails before being counted.
            if prompt_emb is None:
                print('encode prompt!!!')
                prompt_emb = encoder.pipe.encode_prompt("A car driving scene captured by front camera")
                del encoder.pipe.prompter

            encoded_data = {
                "latents": latents.cpu(),
                "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
                "image_emb": {}
            }

            torch.save(encoded_data, encoded_path)
            print(f"Saved encoded data: {encoded_path}")
            processed_count += 1

        except Exception as e:
            # Best-effort batch job: report and move on to the next scene.
            print(f"Error encoding scene {scene_name}: {e}")
            continue

    print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 131 |
+
|
| 132 |
+
if __name__ == "__main__":
    # CLI entry point: parse checkpoint/data paths and encode all scenes.
    parser = argparse.ArgumentParser()
    parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")

    args = parser.parse_args()
    encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path)
|
scripts/encode_openx.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
# 🔧 关键修复:设置环境变量避免GCS连接
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
os.environ["TFDS_DISABLE_GCS"] = "1"
|
| 17 |
+
|
| 18 |
+
import tensorflow_datasets as tfds
|
| 19 |
+
import tensorflow as tf
|
| 20 |
+
|
| 21 |
+
class VideoEncoder(pl.LightningModule):
    """Encodes Open-X fractal episodes (frames + accumulated camera poses) with the Wan2.1 VAE."""

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiling kwargs forwarded to pipe.encode_video to bound VAE memory use.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Map PIL frames to float tensors normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=832, target_height=480):
        """Resize a PIL image to the target resolution (bilinear)."""
        return v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_episode_frames(self, episode_data, max_frames=300):
        """Extract up to ``max_frames`` RGB frames from a fractal episode.

        Tries several observation keys per step ('image' first), converts each
        hit to a normalized frame, and returns a (C, T, H, W) tensor, or
        ``None`` when no frame could be extracted.
        """
        frames = []

        steps = episode_data['steps']
        frame_count = 0

        print(f"开始提取帧,最多 {max_frames} 帧...")

        for step_idx, step in enumerate(steps):
            if frame_count >= max_frames:
                break

            try:
                obs = step['observation']

                # Prefer the confirmed 'image' field; fall back to other camera keys.
                img_data = None
                image_keys_to_try = [
                    'image',                    # primary image field (confirmed present)
                    'rgb',                      # fallback RGB image
                    'camera_image',             # fallback camera image
                    'exterior_image_1_left',    # possible exterior camera
                    'wrist_image',              # possible wrist camera
                ]

                for img_key in image_keys_to_try:
                    if img_key in obs:
                        try:
                            img_tensor = obs[img_key]
                            img_data = img_tensor.numpy()
                            if step_idx < 3:  # log only the first few steps
                                print(f"✅ 找到图像字段: {img_key}, 形状: {img_data.shape}")
                            break
                        except Exception as e:
                            if step_idx < 3:
                                print(f"尝试字段 {img_key} 失败: {e}")
                            continue

                if img_data is not None:
                    # Normalize to an RGB uint8 PIL image.
                    if len(img_data.shape) == 3:  # [H, W, C]
                        if img_data.dtype == np.uint8:
                            frame = Image.fromarray(img_data)
                        else:
                            # Floats in [0, 1] are rescaled; otherwise cast directly.
                            if img_data.max() <= 1.0:
                                img_data = (img_data * 255).astype(np.uint8)
                            else:
                                img_data = img_data.astype(np.uint8)
                            frame = Image.fromarray(img_data)

                        if frame.mode != 'RGB':
                            frame = frame.convert('RGB')

                        frame = self.crop_and_resize(frame)
                        frame = self.frame_process(frame)
                        frames.append(frame)
                        frame_count += 1

                        if frame_count % 50 == 0:
                            print(f"已处理 {frame_count} 帧")
                    else:
                        if step_idx < 5:
                            print(f"步骤 {step_idx}: 图像形状不正确 {img_data.shape}")
                else:
                    # No image found: log the available observation keys.
                    if step_idx < 5:
                        available_keys = list(obs.keys())
                        print(f"步骤 {step_idx}: 未找到图像,可用键: {available_keys}")

            except Exception as e:
                print(f"处理步骤 {step_idx} 时出错: {e}")
                continue

        print(f"成功提取 {len(frames)} 帧")

        if len(frames) == 0:
            return None

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames

    @staticmethod
    def _euler_to_quaternion(euler_angles):
        """Convert (roll, pitch, yaw) Euler angles to a [w, x, y, z] quaternion (ZYX order)."""
        roll, pitch, yaw = euler_angles[0], euler_angles[1], euler_angles[2]

        cy = np.cos(yaw * 0.5)
        sy = np.sin(yaw * 0.5)
        cp = np.cos(pitch * 0.5)
        sp = np.sin(pitch * 0.5)
        cr = np.cos(roll * 0.5)
        sr = np.sin(roll * 0.5)

        qw = cr * cp * cy + sr * sp * sy
        qx = sr * cp * cy - cr * sp * sy
        qy = cr * sp * cy + sr * cp * sy
        qz = cr * cp * sy - sr * sp * cy

        return np.array([qw, qx, qy, qz], dtype=np.float32)

    def extract_camera_poses(self, episode_data, num_frames):
        """Accumulate per-step action deltas into world-frame camera poses.

        Sums ``action.world_vector`` into a cumulative translation and
        ``action.rotation_delta`` into cumulative Euler angles (converted to a
        quaternion per step).  Returns a list of dicts with 'translation'
        and/or 'rotation' keys, one per frame.
        """
        camera_poses = []

        steps = episode_data['steps']
        frame_count = 0

        print("提取相机位姿信息...")

        # Accumulated pose state across steps.
        cumulative_translation = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        cumulative_rotation = np.array([0.0, 0.0, 0.0], dtype=np.float32)  # Euler angles

        for step_idx, step in enumerate(steps):
            if frame_count >= num_frames:
                break

            try:
                # Accessing 'observation' may raise and route to the fallback below.
                obs = step['observation']
                action = step.get('action', {})

                pose_data = {}
                found_pose = False

                # 1. Prefer action.world_vector (translation in world frame).
                if 'world_vector' in action:
                    try:
                        world_vector = action['world_vector'].numpy()
                        if len(world_vector) == 3:
                            cumulative_translation += world_vector
                            pose_data['translation'] = cumulative_translation.copy()
                            found_pose = True

                            if step_idx < 3:
                                print(f"使用action.world_vector: {world_vector}, 累积位移: {cumulative_translation}")
                    except Exception as e:
                        if step_idx < 3:
                            print(f"action.world_vector提取失败: {e}")

                # 2. Use action.rotation_delta (rotation change).
                if 'rotation_delta' in action:
                    try:
                        rotation_delta = action['rotation_delta'].numpy()
                        if len(rotation_delta) == 3:
                            cumulative_rotation += rotation_delta
                            pose_data['rotation'] = self._euler_to_quaternion(cumulative_rotation)
                            found_pose = True

                            if step_idx < 3:
                                print(f"使用action.rotation_delta: {rotation_delta}, 累积旋转: {cumulative_rotation}")
                    except Exception as e:
                        if step_idx < 3:
                            print(f"action.rotation_delta提取失败: {e}")

                # Ensure the rotation field always exists.
                # NOTE(review): 'translation' can still be absent when
                # world_vector is missing — create_camera_matrices would then
                # raise KeyError; confirm world_vector is always present.
                if 'rotation' not in pose_data:
                    pose_data['rotation'] = self._euler_to_quaternion(cumulative_rotation)

                camera_poses.append(pose_data)
                frame_count += 1

            except Exception as e:
                print(f"提取位姿步骤 {step_idx} 时出错: {e}")
                # Fall back to the last accumulated translation with identity rotation.
                pose_data = {
                    'translation': cumulative_translation.copy(),
                    'rotation': np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
                }
                camera_poses.append(pose_data)
                frame_count += 1

        print(f"提取了 {len(camera_poses)} 个位姿")
        print(f"最终累积位移: {cumulative_translation}")
        print(f"最终累积旋转: {cumulative_rotation}")

        return camera_poses

    def create_camera_matrices(self, camera_poses):
        """Convert pose dicts into 4x4 homogeneous transform matrices."""
        matrices = []

        for pose in camera_poses:
            matrix = np.eye(4, dtype=np.float32)

            # Translation component.
            matrix[:3, 3] = pose['translation']

            # Rotation component: quaternion [w, x, y, z] -> rotation matrix.
            if len(pose['rotation']) == 4:
                q = pose['rotation']
                w, x, y, z = q[0], q[1], q[2], q[3]

                matrix[0, 0] = 1 - 2*(y*y + z*z)
                matrix[0, 1] = 2*(x*y - w*z)
                matrix[0, 2] = 2*(x*z + w*y)
                matrix[1, 0] = 2*(x*y + w*z)
                matrix[1, 1] = 1 - 2*(x*x + z*z)
                matrix[1, 2] = 2*(y*z - w*x)
                matrix[2, 0] = 2*(x*z - w*y)
                matrix[2, 1] = 2*(y*z + w*x)
                matrix[2, 2] = 1 - 2*(x*x + y*y)
            elif len(pose['rotation']) == 3:
                # Euler-angle rotations are not converted; identity is kept
                # (matches the original behavior).
                pass

            matrices.append(matrix)

        return np.array(matrices)
|
| 275 |
+
|
| 276 |
+
def encode_fractal_dataset(dataset_path, text_encoder_path, vae_path, output_dir, max_episodes=None):
    """Encode the fractal20220817_data dataset: per-episode latents, camera poses, and a shared prompt embedding.

    Each episode is written to ``output_dir/episode_NNNNNN/encoded_video.pth``;
    already-encoded episodes are skipped so the job can be resumed.
    """

    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    os.makedirs(output_dir, exist_ok=True)

    processed_count = 0
    prompt_emb = None  # encoded once on the first successful episode

    try:
        # Load the dataset from a local tensorflow_datasets directory.
        ds = tfds.load(
            "fractal20220817_data",
            split="train",
            data_dir=dataset_path,
        )

        print(f"✅ 成功加载fractal20220817_data数据集")

        # Optionally cap the number of episodes to process.
        if max_episodes:
            ds = ds.take(max_episodes)
            print(f"限制处理episodes数量: {max_episodes}")

    except Exception as e:
        print(f"❌ 加载数据集失败: {e}")
        return

    for episode_idx, episode in enumerate(tqdm(ds, desc="处理episodes")):
        try:
            episode_name = f"episode_{episode_idx:06d}"
            save_episode_dir = os.path.join(output_dir, episode_name)

            # Resume support: skip episodes that already have an output file.
            encoded_path = os.path.join(save_episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                print(f"Episode {episode_name} 已处理,跳过...")
                processed_count += 1
                continue

            os.makedirs(save_episode_dir, exist_ok=True)

            print(f"\n🔧 处理episode {episode_name}...")

            # Structure analysis: print episode/step/observation/action keys
            # for the first two episodes only (debug aid).
            if episode_idx < 2:
                print("Episode结构分析:")
                for key in episode.keys():
                    print(f"  - {key}: {type(episode[key])}")

                # Inspect the first step's structure.
                steps = episode['steps']
                for step in steps.take(1):
                    print("第一个step结构:")
                    for key in step.keys():
                        print(f"  - {key}: {type(step[key])}")

                    if 'observation' in step:
                        obs = step['observation']
                        print("  observation键:")
                        print(f"    🔍 可用字段: {list(obs.keys())}")

                        # Check the image- and pose-related observation fields.
                        key_fields = ['image', 'vector_to_go', 'rotation_delta_to_go', 'base_pose_tool_reached']
                        for key in key_fields:
                            if key in obs:
                                try:
                                    value = obs[key]
                                    if hasattr(value, 'shape'):
                                        print(f"    ✅ {key}: {type(value)}, shape: {value.shape}")
                                    else:
                                        print(f"    ✅ {key}: {type(value)}")
                                except Exception as e:
                                    print(f"    ❌ {key}: 无法访问 ({e})")

                    if 'action' in step:
                        action = step['action']
                        print("  action键:")
                        print(f"    🔍 可用字段: {list(action.keys())}")

                        # Check the pose-related action fields.
                        key_fields = ['world_vector', 'rotation_delta', 'base_displacement_vector']
                        for key in key_fields:
                            if key in action:
                                try:
                                    value = action[key]
                                    if hasattr(value, 'shape'):
                                        print(f"    ✅ {key}: {type(value)}, shape: {value.shape}")
                                    else:
                                        print(f"    ✅ {key}: {type(value)}")
                                except Exception as e:
                                    print(f"    ❌ {key}: 无法访问 ({e})")

            # Load video frames from the episode.
            video_frames = encoder.load_episode_frames(episode)
            if video_frames is None:
                print(f"❌ 无法加载episode {episode_name}的视频帧")
                continue

            print(f"✅ Episode {episode_name} 视频形状: {video_frames.shape}")

            # Extract camera poses (one per frame) and convert to matrices.
            num_frames = video_frames.shape[1]
            camera_poses = encoder.extract_camera_poses(episode, num_frames)
            camera_matrices = encoder.create_camera_matrices(camera_poses)

            print(f"🔧 编码episode {episode_name}...")

            # Camera data bundle; intrinsics are a placeholder identity.
            cam_emb = {
                'extrinsic': camera_matrices,
                'intrinsic': np.eye(3, dtype=np.float32)
            }

            # Encode the video to latents.
            frames_batch = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)

            with torch.no_grad():
                latents = encoder.pipe.encode_video(frames_batch, **encoder.tiler_kwargs)[0]

            # Encode the text prompt once, then free the prompter to save memory.
            if prompt_emb is None:
                print('🔧 编码prompt...')
                prompt_emb = encoder.pipe.encode_prompt(
                    "A video of robotic manipulation task with camera movement"
                )
                # Release the prompter to save memory.
                del encoder.pipe.prompter

            # Persist the encoded episode.
            encoded_data = {
                "latents": latents.cpu(),
                "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v
                             for k, v in prompt_emb.items()},
                "cam_emb": cam_emb,
                "episode_info": {
                    "episode_idx": episode_idx,
                    "total_frames": video_frames.shape[1],
                    "pose_extraction_method": "observation_action_based"
                }
            }

            torch.save(encoded_data, encoded_path)
            print(f"✅ 保存编码数据: {encoded_path}")

            processed_count += 1
            print(f"✅ 已处理 {processed_count} 个episodes")

        except Exception as e:
            # Best-effort batch job: log the full traceback and continue.
            print(f"❌ 处理episode {episode_idx}时出错: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"🎉 编码完成! 总共处理了 {processed_count} 个episodes")
|
| 434 |
+
if __name__ == "__main__":
    # CLI entry point: configure paths, then encode the fractal dataset.
    parser = argparse.ArgumentParser(description="Encode Open-X Fractal20220817 Dataset - Based on Real Structure")
    parser.add_argument("--dataset_path", type=str,
                        default="/share_zhuyixuan05/public_datasets/open-x/0.1.0",
                        help="Path to tensorflow_datasets directory")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
    parser.add_argument("--output_dir", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded")
    parser.add_argument("--max_episodes", type=int, default=10000,
                        help="Maximum number of episodes to process (default: 10 for testing)")

    args = parser.parse_args()

    # Make sure the output directory exists before encoding starts.
    os.makedirs(args.output_dir, exist_ok=True)

    print("🚀 开始编码Open-X Fractal数据集 (基于实际字段结构)...")
    print(f"📁 数据集路径: {args.dataset_path}")
    print(f"💾 输出目录: {args.output_dir}")
    print(f"🔢 最大处理episodes: {args.max_episodes}")
    print("🔧 基于实际observation和action字段的位姿提取方法")
    print("✅ 优先使用 'image' 字段获取图像数据")

    encode_fractal_dataset(
        args.dataset_path,
        args.text_encoder_path,
        args.vae_path,
        args.output_dir,
        args.max_episodes
    )
|
scripts/encode_rlbench_video.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pdb
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
| 17 |
+
class VideoEncoder(pl.LightningModule):
    """Wraps the ReCamMaster VAE/text-encoder pipeline for offline RLBench video encoding."""

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiling kwargs forwarded to pipe.encode_video to bound VAE memory use.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Map PIL frames to float tensors normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=512, target_height=512):
        """Resize a PIL image to the target resolution (default 512x512, bilinear)."""
        return v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Load a whole video as a normalized tensor of shape (C, T, H, W).

        Returns ``None`` when the video contains no decodable frames.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frames.append(self.frame_process(frame))
        finally:
            # Release the reader even if a frame fails to decode.
            reader.close()

        if not frames:
            return None

        stacked = torch.stack(frames, dim=0)  # (T, C, H, W)
        return rearrange(stacked, "T C H W -> C T H W")
|
| 62 |
+
|
| 63 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Encode every RLBench demo video (latents + camera trajectory) into ``output_dir``.

    Expected layout: ``scenes_path/<task>/<demo>/*.mp4`` with a camera file
    next to each ``...side.mp4`` named ``...data.npy``.  Each encoded demo is
    written to ``output_dir/<task>_<demo>/encoded_video.pth``.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0

    os.makedirs(output_dir, exist_ok=True)

    for i, scene_name in enumerate(os.listdir(scenes_path)):
        scene_dir = os.path.join(scenes_path, scene_name)
        for j, demo_name in tqdm(enumerate(os.listdir(scene_dir)), total=len(os.listdir(scene_dir))):
            demo_dir = os.path.join(scene_dir, demo_name)
            for filename in os.listdir(demo_dir):
                # Only process .mp4 videos (case-insensitive).
                if not filename.lower().endswith('.mp4'):
                    continue

                full_path = os.path.join(demo_dir, filename)
                print(full_path)
                save_dir = os.path.join(output_dir, scene_name + '_' + demo_name)

                os.makedirs(save_dir, exist_ok=True)

                # Skip demos that were already encoded.
                encoded_path = os.path.join(save_dir, "encoded_video.pth")
                if os.path.exists(encoded_path):
                    print(f"Scene {scene_name} already encoded, skipping...")
                    continue

                # Camera trajectory stored next to the video: side.mp4 -> data.npy.
                scene_cam_path = full_path.replace("side.mp4", "data.npy")
                print(scene_cam_path)
                if not os.path.exists(scene_cam_path):
                    continue

                cam_emb = np.load(scene_cam_path)
                print(cam_emb.shape)

                video_path = full_path
                if not os.path.exists(video_path):
                    print(f"Video not found: {video_path}")
                    continue

                try:
                    print(f"Encoding scene {scene_name}...Demo {demo_name}")

                    # Load and encode the video.
                    video_frames = encoder.load_video_frames(video_path)
                    if video_frames is None:
                        print(f"Failed to load video: {video_path}")
                        continue

                    video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
                    print('video shape:', video_frames.shape)
                    with torch.no_grad():
                        latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]

                    # Persist latents plus the raw camera trajectory array.
                    encoded_data = {
                        "latents": latents.cpu(),
                        "cam_emb": cam_emb
                    }
                    torch.save(encoded_data, encoded_path)
                    print(f"Saved encoded data: {encoded_path}")
                    processed_count += 1

                except Exception as e:
                    # Best-effort batch job: report and continue with the next demo.
                    print(f"Error encoding scene {scene_name}: {e}")
                    continue

    print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 157 |
+
|
| 158 |
+
if __name__ == "__main__":
    # CLI entry point: parse checkpoint/data paths and encode all RLBench demos.
    parser = argparse.ArgumentParser()
    parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/RLBench")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")

    parser.add_argument("--output_dir",type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/rlbench")

    args = parser.parse_args()
    encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
|
scripts/encode_sekai_video.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pdb
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
| 17 |
+
class VideoEncoder(pl.LightningModule):
|
| 18 |
+
def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 19 |
+
super().__init__()
|
| 20 |
+
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
| 21 |
+
model_manager.load_models([text_encoder_path, vae_path])
|
| 22 |
+
self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
|
| 23 |
+
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
| 24 |
+
|
| 25 |
+
self.frame_process = v2.Compose([
|
| 26 |
+
# v2.CenterCrop(size=(900, 1600)),
|
| 27 |
+
# v2.Resize(size=(900, 1600), antialias=True),
|
| 28 |
+
v2.ToTensor(),
|
| 29 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 30 |
+
])
|
| 31 |
+
|
| 32 |
+
def crop_and_resize(self, image):
|
| 33 |
+
width, height = image.size
|
| 34 |
+
# print(width,height)
|
| 35 |
+
width_ori, height_ori_ = 832 , 480
|
| 36 |
+
image = v2.functional.resize(
|
| 37 |
+
image,
|
| 38 |
+
(round(height_ori_), round(width_ori)),
|
| 39 |
+
interpolation=v2.InterpolationMode.BILINEAR
|
| 40 |
+
)
|
| 41 |
+
return image
|
| 42 |
+
|
| 43 |
+
def load_video_frames(self, video_path):
|
| 44 |
+
"""加载完整视频"""
|
| 45 |
+
reader = imageio.get_reader(video_path)
|
| 46 |
+
frames = []
|
| 47 |
+
|
| 48 |
+
for frame_data in reader:
|
| 49 |
+
frame = Image.fromarray(frame_data)
|
| 50 |
+
frame = self.crop_and_resize(frame)
|
| 51 |
+
frame = self.frame_process(frame)
|
| 52 |
+
frames.append(frame)
|
| 53 |
+
|
| 54 |
+
reader.close()
|
| 55 |
+
|
| 56 |
+
if len(frames) == 0:
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
frames = torch.stack(frames, dim=0)
|
| 60 |
+
frames = rearrange(frames, "T C H W -> C T H W")
|
| 61 |
+
return frames
|
| 62 |
+
|
| 63 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
|
| 64 |
+
"""编码所有场景的视频"""
|
| 65 |
+
|
| 66 |
+
encoder = VideoEncoder(text_encoder_path, vae_path)
|
| 67 |
+
encoder = encoder.cuda()
|
| 68 |
+
encoder.pipe.device = "cuda"
|
| 69 |
+
|
| 70 |
+
processed_count = 0
|
| 71 |
+
prompt_emb = 0
|
| 72 |
+
|
| 73 |
+
os.makedirs(output_dir,exist_ok=True)
|
| 74 |
+
|
| 75 |
+
for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))):
|
| 76 |
+
# if i < 1700:
|
| 77 |
+
# continue
|
| 78 |
+
scene_dir = os.path.join(scenes_path, scene_name)
|
| 79 |
+
save_dir = os.path.join(output_dir,scene_name.split('.')[0])
|
| 80 |
+
# print('in:',scene_dir)
|
| 81 |
+
# print('out:',save_dir)
|
| 82 |
+
|
| 83 |
+
if not scene_dir.endswith(".mp4"):# or os.path.isdir(output_dir):
|
| 84 |
+
continue
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
os.makedirs(save_dir,exist_ok=True)
|
| 88 |
+
# 检查是否已编码
|
| 89 |
+
encoded_path = os.path.join(save_dir, "encoded_video.pth")
|
| 90 |
+
if os.path.exists(encoded_path):
|
| 91 |
+
print(f"Scene {scene_name} already encoded, skipping...")
|
| 92 |
+
continue
|
| 93 |
+
|
| 94 |
+
# 加载场景信息
|
| 95 |
+
|
| 96 |
+
scene_cam_path = scene_dir.replace(".mp4", ".npz")
|
| 97 |
+
if not os.path.exists(scene_cam_path):
|
| 98 |
+
continue
|
| 99 |
+
|
| 100 |
+
with np.load(scene_cam_path) as data:
|
| 101 |
+
cam_data = data.files
|
| 102 |
+
cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}
|
| 103 |
+
# with open(scene_cam_path, 'rb') as f:
|
| 104 |
+
# cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用
|
| 105 |
+
|
| 106 |
+
# 加载视频
|
| 107 |
+
video_path = scene_dir
|
| 108 |
+
if not os.path.exists(video_path):
|
| 109 |
+
print(f"Video not found: {video_path}")
|
| 110 |
+
continue
|
| 111 |
+
|
| 112 |
+
# try:
|
| 113 |
+
print(f"Encoding scene {scene_name}...")
|
| 114 |
+
|
| 115 |
+
# 加载和编码视频
|
| 116 |
+
video_frames = encoder.load_video_frames(video_path)
|
| 117 |
+
if video_frames is None:
|
| 118 |
+
print(f"Failed to load video: {video_path}")
|
| 119 |
+
continue
|
| 120 |
+
|
| 121 |
+
video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
|
| 122 |
+
print('video shape:',video_frames.shape)
|
| 123 |
+
# 编码视频
|
| 124 |
+
with torch.no_grad():
|
| 125 |
+
latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]
|
| 126 |
+
|
| 127 |
+
# 编码文本
|
| 128 |
+
if processed_count == 0:
|
| 129 |
+
print('encode prompt!!!')
|
| 130 |
+
prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
|
| 131 |
+
del encoder.pipe.prompter
|
| 132 |
+
# pdb.set_trace()
|
| 133 |
+
# 保存编码结果
|
| 134 |
+
encoded_data = {
|
| 135 |
+
"latents": latents.cpu(),
|
| 136 |
+
#"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
|
| 137 |
+
"cam_emb": cam_emb
|
| 138 |
+
}
|
| 139 |
+
# pdb.set_trace()
|
| 140 |
+
torch.save(encoded_data, encoded_path)
|
| 141 |
+
print(f"Saved encoded data: {encoded_path}")
|
| 142 |
+
processed_count += 1
|
| 143 |
+
|
| 144 |
+
# except Exception as e:
|
| 145 |
+
# print(f"Error encoding scene {scene_name}: {e}")
|
| 146 |
+
# continue
|
| 147 |
+
|
| 148 |
+
print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 149 |
+
|
| 150 |
+
if __name__ == "__main__":
|
| 151 |
+
parser = argparse.ArgumentParser()
|
| 152 |
+
parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/public_datasets/sekai/Sekai-Project/sekai-game-walking")
|
| 153 |
+
parser.add_argument("--text_encoder_path", type=str,
|
| 154 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
|
| 155 |
+
parser.add_argument("--vae_path", type=str,
|
| 156 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
|
| 157 |
+
|
| 158 |
+
parser.add_argument("--output_dir",type=str,
|
| 159 |
+
default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking")
|
| 160 |
+
|
| 161 |
+
args = parser.parse_args()
|
| 162 |
+
encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
|
scripts/encode_sekai_walking.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import lightning as pl
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 7 |
+
import json
|
| 8 |
+
import imageio
|
| 9 |
+
from torchvision.transforms import v2
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
import argparse
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pdb
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 17 |
+
|
| 18 |
+
class VideoEncoder(pl.LightningModule):
|
| 19 |
+
def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 20 |
+
super().__init__()
|
| 21 |
+
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
| 22 |
+
model_manager.load_models([text_encoder_path, vae_path])
|
| 23 |
+
self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
|
| 24 |
+
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
| 25 |
+
|
| 26 |
+
self.frame_process = v2.Compose([
|
| 27 |
+
# v2.CenterCrop(size=(900, 1600)),
|
| 28 |
+
# v2.Resize(size=(900, 1600), antialias=True),
|
| 29 |
+
v2.ToTensor(),
|
| 30 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 31 |
+
])
|
| 32 |
+
|
| 33 |
+
def crop_and_resize(self, image):
|
| 34 |
+
width, height = image.size
|
| 35 |
+
# print(width,height)
|
| 36 |
+
width_ori, height_ori_ = 832 , 480
|
| 37 |
+
image = v2.functional.resize(
|
| 38 |
+
image,
|
| 39 |
+
(round(height_ori_), round(width_ori)),
|
| 40 |
+
interpolation=v2.InterpolationMode.BILINEAR
|
| 41 |
+
)
|
| 42 |
+
return image
|
| 43 |
+
|
| 44 |
+
def load_video_frames(self, video_path):
|
| 45 |
+
"""加载完整视频"""
|
| 46 |
+
reader = imageio.get_reader(video_path)
|
| 47 |
+
frames = []
|
| 48 |
+
|
| 49 |
+
for frame_data in reader:
|
| 50 |
+
frame = Image.fromarray(frame_data)
|
| 51 |
+
frame = self.crop_and_resize(frame)
|
| 52 |
+
frame = self.frame_process(frame)
|
| 53 |
+
frames.append(frame)
|
| 54 |
+
|
| 55 |
+
reader.close()
|
| 56 |
+
|
| 57 |
+
if len(frames) == 0:
|
| 58 |
+
return None
|
| 59 |
+
|
| 60 |
+
frames = torch.stack(frames, dim=0)
|
| 61 |
+
frames = rearrange(frames, "T C H W -> C T H W")
|
| 62 |
+
return frames
|
| 63 |
+
|
| 64 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
|
| 65 |
+
"""编码所有场景的视频"""
|
| 66 |
+
|
| 67 |
+
encoder = VideoEncoder(text_encoder_path, vae_path)
|
| 68 |
+
encoder = encoder.cuda()
|
| 69 |
+
encoder.pipe.device = "cuda"
|
| 70 |
+
|
| 71 |
+
processed_count = 0
|
| 72 |
+
|
| 73 |
+
processed_chunk_count = 0
|
| 74 |
+
|
| 75 |
+
prompt_emb = 0
|
| 76 |
+
|
| 77 |
+
os.makedirs(output_dir,exist_ok=True)
|
| 78 |
+
chunk_size = 300
|
| 79 |
+
for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))):
|
| 80 |
+
# print('index-----:',type(i))
|
| 81 |
+
# if i < 3000 :#or i >=2000:
|
| 82 |
+
# # print('index-----:',i)
|
| 83 |
+
# continue
|
| 84 |
+
# print('index:',i)
|
| 85 |
+
print('index:',i)
|
| 86 |
+
scene_dir = os.path.join(scenes_path, scene_name)
|
| 87 |
+
|
| 88 |
+
# save_dir = os.path.join(output_dir,scene_name.split('.')[0])
|
| 89 |
+
# print('in:',scene_dir)
|
| 90 |
+
# print('out:',save_dir)
|
| 91 |
+
|
| 92 |
+
if not scene_dir.endswith(".mp4"):# or os.path.isdir(output_dir):
|
| 93 |
+
continue
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
scene_cam_path = scene_dir.replace(".mp4", ".npz")
|
| 97 |
+
if not os.path.exists(scene_cam_path):
|
| 98 |
+
continue
|
| 99 |
+
|
| 100 |
+
with np.load(scene_cam_path) as data:
|
| 101 |
+
cam_data = data.files
|
| 102 |
+
cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}
|
| 103 |
+
# with open(scene_cam_path, 'rb') as f:
|
| 104 |
+
# cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用
|
| 105 |
+
|
| 106 |
+
video_name = scene_name[:-4].split('_')[0]
|
| 107 |
+
start_frame = int(scene_name[:-4].split('_')[1])
|
| 108 |
+
end_frame = int(scene_name[:-4].split('_')[2])
|
| 109 |
+
|
| 110 |
+
sampled_range = range(start_frame, end_frame , chunk_size)
|
| 111 |
+
sampled_frames = list(sampled_range)
|
| 112 |
+
|
| 113 |
+
sampled_chunk_end = sampled_frames[0] + 300
|
| 114 |
+
start_str = f"{sampled_frames[0]:07d}"
|
| 115 |
+
end_str = f"{sampled_chunk_end:07d}"
|
| 116 |
+
|
| 117 |
+
chunk_name = f"{video_name}_{start_str}_{end_str}"
|
| 118 |
+
save_chunk_path = os.path.join(output_dir,chunk_name,"encoded_video.pth")
|
| 119 |
+
|
| 120 |
+
if os.path.exists(save_chunk_path):
|
| 121 |
+
print(f"Video {video_name} already encoded, skipping...")
|
| 122 |
+
continue
|
| 123 |
+
|
| 124 |
+
# 加载视频
|
| 125 |
+
video_path = scene_dir
|
| 126 |
+
if not os.path.exists(video_path):
|
| 127 |
+
print(f"Video not found: {video_path}")
|
| 128 |
+
continue
|
| 129 |
+
|
| 130 |
+
video_frames = encoder.load_video_frames(video_path)
|
| 131 |
+
if video_frames is None:
|
| 132 |
+
print(f"Failed to load video: {video_path}")
|
| 133 |
+
continue
|
| 134 |
+
|
| 135 |
+
video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
|
| 136 |
+
print('video shape:',video_frames.shape)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# print(sampled_frames)
|
| 141 |
+
|
| 142 |
+
print(f"Encoding scene {scene_name}...")
|
| 143 |
+
for sampled_chunk_start in sampled_frames:
|
| 144 |
+
sampled_chunk_end = sampled_chunk_start + 300
|
| 145 |
+
start_str = f"{sampled_chunk_start:07d}"
|
| 146 |
+
end_str = f"{sampled_chunk_end:07d}"
|
| 147 |
+
|
| 148 |
+
# 生成保存目录名(假设video_name已定义)
|
| 149 |
+
chunk_name = f"{video_name}_{start_str}_{end_str}"
|
| 150 |
+
save_chunk_dir = os.path.join(output_dir,chunk_name)
|
| 151 |
+
|
| 152 |
+
os.makedirs(save_chunk_dir,exist_ok=True)
|
| 153 |
+
print(f"Encoding chunk {chunk_name}...")
|
| 154 |
+
|
| 155 |
+
encoded_path = os.path.join(save_chunk_dir, "encoded_video.pth")
|
| 156 |
+
|
| 157 |
+
if os.path.exists(encoded_path):
|
| 158 |
+
print(f"Chunk {chunk_name} already encoded, skipping...")
|
| 159 |
+
continue
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
chunk_frames = video_frames[:,:, sampled_chunk_start - start_frame : sampled_chunk_end - start_frame,...]
|
| 163 |
+
# print('extrinsic:',cam_emb['extrinsic'].shape)
|
| 164 |
+
chunk_cam_emb ={'extrinsic':cam_emb['extrinsic'][sampled_chunk_start - start_frame : sampled_chunk_end - start_frame],
|
| 165 |
+
'intrinsic':cam_emb['intrinsic']}
|
| 166 |
+
|
| 167 |
+
# print('chunk shape:',chunk_frames.shape)
|
| 168 |
+
|
| 169 |
+
with torch.no_grad():
|
| 170 |
+
latents = encoder.pipe.encode_video(chunk_frames, **encoder.tiler_kwargs)[0]
|
| 171 |
+
|
| 172 |
+
# 编码文本
|
| 173 |
+
# if processed_count == 0:
|
| 174 |
+
# print('encode prompt!!!')
|
| 175 |
+
# prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
|
| 176 |
+
# del encoder.pipe.prompter
|
| 177 |
+
# pdb.set_trace()
|
| 178 |
+
# 保存编码结果
|
| 179 |
+
encoded_data = {
|
| 180 |
+
"latents": latents.cpu(),
|
| 181 |
+
# "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
|
| 182 |
+
"cam_emb": chunk_cam_emb
|
| 183 |
+
}
|
| 184 |
+
# pdb.set_trace()
|
| 185 |
+
torch.save(encoded_data, encoded_path)
|
| 186 |
+
print(f"Saved encoded data: {encoded_path}")
|
| 187 |
+
processed_chunk_count += 1
|
| 188 |
+
|
| 189 |
+
processed_count += 1
|
| 190 |
+
|
| 191 |
+
print("Encoded scene numebr:",processed_count)
|
| 192 |
+
print("Encoded chunk numebr:",processed_chunk_count)
|
| 193 |
+
|
| 194 |
+
# os.makedirs(save_dir,exist_ok=True)
|
| 195 |
+
# # 检查是否已编码
|
| 196 |
+
# encoded_path = os.path.join(save_dir, "encoded_video.pth")
|
| 197 |
+
# if os.path.exists(encoded_path):
|
| 198 |
+
# print(f"Scene {scene_name} already encoded, skipping...")
|
| 199 |
+
# continue
|
| 200 |
+
|
| 201 |
+
# 加载场景信息
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# try:
|
| 206 |
+
# print(f"Encoding scene {scene_name}...")
|
| 207 |
+
|
| 208 |
+
# 加载和编码视频
|
| 209 |
+
|
| 210 |
+
# 编码视频
|
| 211 |
+
# with torch.no_grad():
|
| 212 |
+
# latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]
|
| 213 |
+
|
| 214 |
+
# # 编码文本
|
| 215 |
+
# if processed_count == 0:
|
| 216 |
+
# print('encode prompt!!!')
|
| 217 |
+
# prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
|
| 218 |
+
# del encoder.pipe.prompter
|
| 219 |
+
# # pdb.set_trace()
|
| 220 |
+
# # 保存编码结果
|
| 221 |
+
# encoded_data = {
|
| 222 |
+
# "latents": latents.cpu(),
|
| 223 |
+
# #"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
|
| 224 |
+
# "cam_emb": cam_emb
|
| 225 |
+
# }
|
| 226 |
+
# # pdb.set_trace()
|
| 227 |
+
# torch.save(encoded_data, encoded_path)
|
| 228 |
+
# print(f"Saved encoded data: {encoded_path}")
|
| 229 |
+
# processed_count += 1
|
| 230 |
+
|
| 231 |
+
# except Exception as e:
|
| 232 |
+
# print(f"Error encoding scene {scene_name}: {e}")
|
| 233 |
+
# continue
|
| 234 |
+
|
| 235 |
+
print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 236 |
+
|
| 237 |
+
if __name__ == "__main__":
|
| 238 |
+
parser = argparse.ArgumentParser()
|
| 239 |
+
parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/public_datasets/sekai/Sekai-Project/sekai-game-walking")
|
| 240 |
+
parser.add_argument("--text_encoder_path", type=str,
|
| 241 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
|
| 242 |
+
parser.add_argument("--vae_path", type=str,
|
| 243 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
|
| 244 |
+
|
| 245 |
+
parser.add_argument("--output_dir",type=str,
|
| 246 |
+
default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking")
|
| 247 |
+
|
| 248 |
+
args = parser.parse_args()
|
| 249 |
+
encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
|
scripts/encode_spatialvid.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import lightning as pl
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 7 |
+
import json
|
| 8 |
+
import imageio
|
| 9 |
+
from torchvision.transforms import v2
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
import argparse
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pdb
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 18 |
+
|
| 19 |
+
from scipy.spatial.transform import Slerp
|
| 20 |
+
from scipy.spatial.transform import Rotation as R
|
| 21 |
+
|
| 22 |
+
def interpolate_camera_poses(original_frames, original_poses, target_frames):
|
| 23 |
+
"""
|
| 24 |
+
对相机姿态进行插值,生成目标帧对应的姿态参数
|
| 25 |
+
|
| 26 |
+
参数:
|
| 27 |
+
original_frames: 原始帧索引列表,如[0,6,12,...]
|
| 28 |
+
original_poses: 原始姿态数组,形状为(n,7),每行[tx, ty, tz, qx, qy, qz, qw]
|
| 29 |
+
target_frames: 目标帧索引列表,如[0,4,8,12,...]
|
| 30 |
+
|
| 31 |
+
返回:
|
| 32 |
+
target_poses: 插值后的姿态数组,形状为(m,7),m为目标帧数量
|
| 33 |
+
"""
|
| 34 |
+
# 确保输入有效
|
| 35 |
+
print('original_frames:',len(original_frames))
|
| 36 |
+
print('original_poses:',len(original_poses))
|
| 37 |
+
if len(original_frames) != len(original_poses):
|
| 38 |
+
raise ValueError("原始帧数量与姿态数量不匹配")
|
| 39 |
+
|
| 40 |
+
if original_poses.shape[1] != 7:
|
| 41 |
+
raise ValueError(f"原始姿态应为(n,7)格式,实际为{original_poses.shape}")
|
| 42 |
+
|
| 43 |
+
target_poses = []
|
| 44 |
+
|
| 45 |
+
# 提取旋转部分并转换为Rotation对象
|
| 46 |
+
rotations = R.from_quat(original_poses[:, 3:7]) # 提取四元数部分
|
| 47 |
+
|
| 48 |
+
for t in target_frames:
|
| 49 |
+
# 找到t前后的原始帧索引
|
| 50 |
+
idx = np.searchsorted(original_frames, t, side='left')
|
| 51 |
+
|
| 52 |
+
# 处理边界情况
|
| 53 |
+
if idx == 0:
|
| 54 |
+
# 使用第一个姿态
|
| 55 |
+
target_poses.append(original_poses[0])
|
| 56 |
+
continue
|
| 57 |
+
if idx >= len(original_frames):
|
| 58 |
+
# 使用最后一个姿态
|
| 59 |
+
target_poses.append(original_poses[-1])
|
| 60 |
+
continue
|
| 61 |
+
|
| 62 |
+
# 获取前后帧的信息
|
| 63 |
+
t_prev, t_next = original_frames[idx-1], original_frames[idx]
|
| 64 |
+
pose_prev, pose_next = original_poses[idx-1], original_poses[idx]
|
| 65 |
+
|
| 66 |
+
# 计算插值权重
|
| 67 |
+
alpha = (t - t_prev) / (t_next - t_prev)
|
| 68 |
+
|
| 69 |
+
# 1. 平移向量的线性插值
|
| 70 |
+
translation_prev = pose_prev[:3]
|
| 71 |
+
translation_next = pose_next[:3]
|
| 72 |
+
interpolated_translation = translation_prev + alpha * (translation_next - translation_prev)
|
| 73 |
+
|
| 74 |
+
# 2. 旋转四元数的球面线性插值(SLERP)
|
| 75 |
+
# 创建Slerp对象
|
| 76 |
+
slerp = Slerp([t_prev, t_next], rotations[idx-1:idx+1])
|
| 77 |
+
interpolated_rotation = slerp(t)
|
| 78 |
+
|
| 79 |
+
# 组合平移和旋转
|
| 80 |
+
interpolated_pose = np.concatenate([
|
| 81 |
+
interpolated_translation,
|
| 82 |
+
interpolated_rotation.as_quat() # 转换回四元数
|
| 83 |
+
])
|
| 84 |
+
|
| 85 |
+
target_poses.append(interpolated_pose)
|
| 86 |
+
|
| 87 |
+
return np.array(target_poses)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class VideoEncoder(pl.LightningModule):
|
| 91 |
+
def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 92 |
+
super().__init__()
|
| 93 |
+
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
| 94 |
+
model_manager.load_models([text_encoder_path, vae_path])
|
| 95 |
+
self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
|
| 96 |
+
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
| 97 |
+
|
| 98 |
+
self.frame_process = v2.Compose([
|
| 99 |
+
# v2.CenterCrop(size=(900, 1600)),
|
| 100 |
+
# v2.Resize(size=(900, 1600), antialias=True),
|
| 101 |
+
v2.ToTensor(),
|
| 102 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 103 |
+
])
|
| 104 |
+
|
| 105 |
+
def crop_and_resize(self, image):
|
| 106 |
+
width, height = image.size
|
| 107 |
+
# print(width,height)
|
| 108 |
+
width_ori, height_ori_ = 832 , 480
|
| 109 |
+
image = v2.functional.resize(
|
| 110 |
+
image,
|
| 111 |
+
(round(height_ori_), round(width_ori)),
|
| 112 |
+
interpolation=v2.InterpolationMode.BILINEAR
|
| 113 |
+
)
|
| 114 |
+
return image
|
| 115 |
+
|
| 116 |
+
def load_video_frames(self, video_path):
|
| 117 |
+
"""加载完整视频"""
|
| 118 |
+
reader = imageio.get_reader(video_path)
|
| 119 |
+
frames = []
|
| 120 |
+
|
| 121 |
+
for frame_data in reader:
|
| 122 |
+
frame = Image.fromarray(frame_data)
|
| 123 |
+
frame = self.crop_and_resize(frame)
|
| 124 |
+
frame = self.frame_process(frame)
|
| 125 |
+
frames.append(frame)
|
| 126 |
+
|
| 127 |
+
reader.close()
|
| 128 |
+
|
| 129 |
+
if len(frames) == 0:
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
frames = torch.stack(frames, dim=0)
|
| 133 |
+
frames = rearrange(frames, "T C H W -> C T H W")
|
| 134 |
+
return frames
|
| 135 |
+
|
| 136 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
    """Encode every SpatialVID scene video into 300-frame chunks of VAE latents,
    camera-pose embeddings and one text-prompt embedding per video.

    Each chunk is written to <output_dir>/<clip>_<start>_<end>/encoded_video.pth.
    Existing chunk files are inspected and only their missing keys are recomputed.

    scenes_path:       root dir containing one sub-directory of .mp4 files per scene group
    text_encoder_path: T5 text-encoder checkpoint path
    vae_path:          Wan VAE checkpoint path
    output_dir:        destination root for the encoded chunks
    """

    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0        # videos fully processed
    processed_chunk_count = 0  # chunk files written overall

    # Placeholder; replaced by the real prompt embedding at each video's first
    # chunk and then reused for that video's remaining chunks.
    prompt_emb = 0

    # Dataset metadata (frame counts per clip id). NOTE(review): hard-coded path.
    metadata = pd.read_csv('/share_zhuyixuan05/public_datasets/SpatialVID-HQ/data/train/SpatialVID_HQ_metadata.csv')

    os.makedirs(output_dir,exist_ok=True)
    chunk_size = 300  # frames per encoded chunk
    required_keys = ["latents", "cam_emb", "prompt_emb"]  # a chunk file is complete when it has all three

    for i, scene_name in enumerate(os.listdir(scenes_path)):
        # Skip the first groups (manual resume/sharding hack for multi-run processing).
        if i < 3 :#or i >=2000:
            continue
        print('group:',i)
        scene_dir = os.path.join(scenes_path, scene_name)

        print('in:',scene_dir)
        for j, video_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))):
            print(video_name)
            video_path = os.path.join(scene_dir, video_name)
            if not video_path.endswith(".mp4"):
                continue

            # Look up this clip's frame count by id (file name without ".mp4").
            video_info = metadata[metadata['id'] == video_name[:-4]]
            num_frames = video_info['num frames'].iloc[0]

            # Annotations live in a parallel tree: .../videos/x.mp4 -> .../annotations/x/
            scene_cam_dir = video_path.replace( "videos","annotations")[:-4]
            scene_cam_path = os.path.join(scene_cam_dir,'poses.npy')

            scene_caption_path = os.path.join(scene_cam_dir,'caption.json')

            with open(scene_caption_path, 'r', encoding='utf-8') as f:
                caption_data = json.load(f)
            caption = caption_data["SceneSummary"]
            if not os.path.exists(scene_cam_path):
                print(f"Pose not found: {scene_cam_path}")
                continue

            # Camera poses, one row per sampled frame. Presumably rows are
            # [tx, ty, tz, qx, qy, qz, qw] as expected by interpolate_camera_poses
            # -- TODO confirm against the annotation format.
            camera_poses = np.load(scene_cam_path)
            cam_data_len = camera_poses.shape[0]

            if not os.path.exists(video_path):
                print(f"Video not found: {video_path}")
                continue

            # If the FIRST chunk of this video is already fully encoded, defer the
            # expensive full-video decode: store the integer 1 as a sentinel and
            # only load pixels later if some chunk actually needs them.
            start_str = f"{0:07d}"
            end_str = f"{chunk_size:07d}"
            chunk_name = f"{video_name[:-4]}_{start_str}_{end_str}"
            first_save_chunk_dir = os.path.join(output_dir,chunk_name)

            first_chunk_encoded_path = os.path.join(first_save_chunk_dir, "encoded_video.pth")
            if os.path.exists(first_chunk_encoded_path):
                data = torch.load(first_chunk_encoded_path,weights_only=False)
                if 'latents' in data:
                    video_frames = 1  # sentinel: pixels not loaded yet
                # NOTE(review): if the file exists but lacks 'latents',
                # video_frames may be left unbound (or stale from the previous
                # iteration) -- confirm this path cannot occur in practice.
            else:
                video_frames = encoder.load_video_frames(video_path)  # (C, T, H, W) or None
                if video_frames is None:
                    print(f"Failed to load video: {video_path}")
                    continue
                print('video shape:',video_frames.shape)

                # Add batch dim -> (1, C, T, H, W) and move to GPU in bf16.
                video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
                print('video shape:',video_frames.shape)

            # Output directories are keyed by the clip id prefix only.
            video_name = video_name[:-4].split('_')[0]
            start_frame = 0
            end_frame = num_frames

            cam_interval = end_frame // (cam_data_len - 1)  # computed but unused below

            # Frame indices at which the cam_data_len poses were sampled (evenly spread).
            cam_frames = np.linspace(start_frame, end_frame, cam_data_len, endpoint=True)
            cam_frames = np.round(cam_frames).astype(int)
            cam_frames = cam_frames.tolist()

            # Chunk start frames: 0, 300, 600, ...
            sampled_range = range(start_frame, end_frame , chunk_size)
            sampled_frames = list(sampled_range)

            sampled_chunk_end = sampled_frames[0] + chunk_size
            start_str = f"{sampled_frames[0]:07d}"
            end_str = f"{sampled_chunk_end:07d}"

            chunk_name = f"{video_name}_{start_str}_{end_str}"

            print(f"Encoding scene {video_name}...")
            chunk_count_in_one_video = 0
            for sampled_chunk_start in sampled_frames:
                # Skip trailing chunks with fewer than 100 frames remaining.
                if num_frames - sampled_chunk_start < 100:
                    continue
                sampled_chunk_end = sampled_chunk_start + chunk_size
                start_str = f"{sampled_chunk_start:07d}"
                end_str = f"{sampled_chunk_end:07d}"

                # Camera poses are re-sampled every 4 frames within the chunk.
                resample_cam_frame = list(range(sampled_chunk_start, sampled_chunk_end , 4))

                chunk_name = f"{video_name}_{start_str}_{end_str}"
                save_chunk_dir = os.path.join(output_dir,chunk_name)

                os.makedirs(save_chunk_dir,exist_ok=True)
                print(f"Encoding chunk {chunk_name}...")

                encoded_path = os.path.join(save_chunk_dir, "encoded_video.pth")

                # Determine which of the three outputs still need to be produced
                # for this chunk; a fully-populated existing file is skipped.
                missing_keys = required_keys
                if os.path.exists(encoded_path):
                    print('error:',encoded_path)
                    data = torch.load(encoded_path,weights_only=False)
                    missing_keys = [key for key in required_keys if key not in data]
                    if missing_keys:
                        print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
                    if len(missing_keys) == 0 :
                        continue  # chunk already complete
                else:
                    print(f"警告: 缺少pth文件: {encoded_path}")
                # Lazily materialise pixels if the non-tensor sentinel is still in place.
                if not isinstance(video_frames, torch.Tensor):
                    video_frames = encoder.load_video_frames(video_path)
                    if video_frames is None:
                        print(f"Failed to load video: {video_path}")
                        continue

                    video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)

                print('video shape:',video_frames.shape)
                if "latents" in missing_keys:
                    # Slice this chunk's frames along the time axis and VAE-encode them.
                    chunk_frames = video_frames[:,:, sampled_chunk_start - start_frame : sampled_chunk_end - start_frame,...]

                    with torch.no_grad():
                        latents = encoder.pipe.encode_video(chunk_frames, **encoder.tiler_kwargs)[0]
                else:
                    latents = data['latents']
                if "cam_emb" in missing_keys:
                    # Interpolate poses onto the chunk's 4-frame grid.
                    cam_emb = interpolate_camera_poses(cam_frames, camera_poses,resample_cam_frame)
                    chunk_cam_emb ={'extrinsic':cam_emb}
                    print(f"视频长度:{chunk_size},重采样相机长度:{cam_emb.shape[0]}")
                else:
                    chunk_cam_emb = data['cam_emb']

                if "prompt_emb" in missing_keys:
                    # Encode the caption only once per video (at its first chunk);
                    # subsequent chunks reuse prompt_emb from that computation.
                    if chunk_count_in_one_video == 0:
                        print(caption)
                        with torch.no_grad():
                            prompt_emb = encoder.pipe.encode_prompt(caption)
                else:
                    prompt_emb = data['prompt_emb']

                # Persist everything on CPU so the file is device-independent.
                encoded_data = {
                    "latents": latents.cpu(),
                    "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
                    "cam_emb": chunk_cam_emb
                }
                torch.save(encoded_data, encoded_path)
                print(f"Saved encoded data: {encoded_path}")
                processed_chunk_count += 1
                chunk_count_in_one_video += 1

            processed_count += 1

            print("Encoded scene numebr:",processed_count)
            print("Encoded chunk numebr:",processed_chunk_count)

            # (A large commented-out whole-video, non-chunked encoding path was
            # removed here; it duplicated the chunk loop above.)

    print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 396 |
+
|
| 397 |
+
if __name__ == "__main__":
    # CLI entry point: encode every scene under --scenes_path into --output_dir.
    cli = argparse.ArgumentParser()
    option_defaults = {
        "scenes_path": "/share_zhuyixuan05/public_datasets/SpatialVID-HQ/SpatialVid/HQ/videos/",
        "text_encoder_path": "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "vae_path": "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
        "output_dir": "/share_zhuyixuan05/zhuyixuan05/spatialvid",
    }
    for flag, default_value in option_defaults.items():
        cli.add_argument(f"--{flag}", type=str, default=default_value)
    opts = cli.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path, opts.vae_path, opts.output_dir)
|
scripts/encode_spatialvid_first_frame.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import lightning as pl
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 7 |
+
import json
|
| 8 |
+
import imageio
|
| 9 |
+
from torchvision.transforms import v2
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
import argparse
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pdb
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 18 |
+
|
| 19 |
+
from scipy.spatial.transform import Slerp
|
| 20 |
+
from scipy.spatial.transform import Rotation as R
|
| 21 |
+
|
| 22 |
+
def interpolate_camera_poses(original_frames, original_poses, target_frames):
    """Interpolate camera poses at the requested frame indices.

    Each pose row is [tx, ty, tz, qx, qy, qz, qw]. Translation is interpolated
    linearly and rotation via quaternion SLERP between the two original frames
    bracketing each target frame; targets outside the original range are
    clamped to the first/last pose.

    original_frames: frame indices of the given poses, e.g. [0, 6, 12, ...]
    original_poses:  (n, 7) pose array
    target_frames:   frame indices to interpolate at, e.g. [0, 4, 8, ...]

    Returns an (m, 7) array, one interpolated pose per target frame.
    """
    print('original_frames:',len(original_frames))
    print('original_poses:',len(original_poses))
    if len(original_frames) != len(original_poses):
        raise ValueError("原始帧数量与姿态数量不匹配")
    if original_poses.shape[1] != 7:
        raise ValueError(f"原始姿态应为(n,7)格式,实际为{original_poses.shape}")

    # All source rotations as one Rotation object (quaternion columns 3..6).
    all_rotations = R.from_quat(original_poses[:, 3:7])

    result = []
    for frame in target_frames:
        # Index of the first original frame >= the target frame.
        pos = np.searchsorted(original_frames, frame, side='left')

        if pos == 0:
            # Before (or at) the first sample: clamp to the first pose.
            result.append(original_poses[0])
        elif pos >= len(original_frames):
            # Past the last sample: clamp to the last pose.
            result.append(original_poses[-1])
        else:
            lo_t, hi_t = original_frames[pos - 1], original_frames[pos]
            lo_pose, hi_pose = original_poses[pos - 1], original_poses[pos]
            weight = (frame - lo_t) / (hi_t - lo_t)

            # Linear interpolation of the translation component.
            trans = lo_pose[:3] + weight * (hi_pose[:3] - lo_pose[:3])

            # SLERP between the two bracketing rotations.
            rot = Slerp([lo_t, hi_t], all_rotations[pos - 1:pos + 1])(frame)

            result.append(np.concatenate([trans, rot.as_quat()]))

    return np.array(result)
|
| 88 |
+
|
| 89 |
+
class VideoEncoder(pl.LightningModule):
    """ReCamMaster pipeline wrapper (text encoder + VAE) used to encode frames offline."""

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        # Both checkpoints are loaded on CPU in bf16; the caller moves the module to GPU.
        manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(manager)
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Per-frame preprocessing: PIL image -> tensor scaled into [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize a PIL image to the fixed 832x480 working resolution (bilinear)."""
        src_w, src_h = image.size
        target_w, target_h = 832, 480
        return v2.functional.resize(
            image,
            (round(target_h), round(target_w)),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_single_frame(self, video_path, frame_idx):
        """Load one frame by index, returned shaped (1, C, 1, H, W); None on failure."""
        reader = imageio.get_reader(video_path)
        try:
            # Seek straight to the requested frame instead of decoding the whole clip.
            raw = reader.get_data(frame_idx)
            tensor = self.frame_process(self.crop_and_resize(Image.fromarray(raw)))
            # [C, H, W] -> [1, C, 1, H, W]: add batch and time dimensions.
            result = tensor.unsqueeze(0).unsqueeze(2)
        except Exception as e:
            print(f"Error loading frame {frame_idx} from {video_path}: {e}")
            return None
        finally:
            reader.close()

        return result

    def load_video_frames(self, video_path):
        """Load a whole video as one (C, T, H, W) tensor (kept for compatibility); None if empty."""
        reader = imageio.get_reader(video_path)
        processed = [
            self.frame_process(self.crop_and_resize(Image.fromarray(raw)))
            for raw in reader
        ]
        reader.close()
        if not processed:
            return None
        stacked = torch.stack(processed, dim=0)
        return rearrange(stacked, "T C H W -> C T H W")
|
| 153 |
+
|
| 154 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
    """For every 300-frame chunk of every scene video, VAE-encode the chunk's
    first frame (repeated 4x along time) and save it as first_latent.pth.

    Output layout: <output_dir>/<clip>_<start>_<end>/first_latent.pth.
    Chunks whose first-latent file already exists are skipped.

    scenes_path:       root dir with one sub-directory of .mp4 clips per scene group
    text_encoder_path: T5 text-encoder checkpoint path
    vae_path:          Wan VAE checkpoint path
    output_dir:        destination root for the first-frame latents
    """

    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0        # videos fully processed
    processed_chunk_count = 0  # chunk latents written overall

    # Dataset metadata (frame counts per clip id). NOTE(review): hard-coded path.
    metadata = pd.read_csv('/share_zhuyixuan05/public_datasets/SpatialVID-HQ/data/train/SpatialVID_HQ_metadata.csv')

    os.makedirs(output_dir,exist_ok=True)
    chunk_size = 300  # frames per chunk

    for i, scene_name in enumerate(os.listdir(scenes_path)):
        # Skip the first groups (manual resume/sharding hack for multi-run processing).
        if i < 2:
            continue
        print('group:',i)
        scene_dir = os.path.join(scenes_path, scene_name)

        print('in:',scene_dir)
        for j, video_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))):
            print(video_name)
            video_path = os.path.join(scene_dir, video_name)
            if not video_path.endswith(".mp4"):
                continue

            # Look up this clip's frame count by id (file name without ".mp4").
            video_info = metadata[metadata['id'] == video_name[:-4]]
            num_frames = video_info['num frames'].iloc[0]

            # Annotations live in a parallel tree: .../videos/x.mp4 -> .../annotations/x/
            scene_cam_dir = video_path.replace("videos","annotations")[:-4]
            scene_cam_path = os.path.join(scene_cam_dir,'poses.npy')
            scene_caption_path = os.path.join(scene_cam_dir,'caption.json')

            with open(scene_caption_path, 'r', encoding='utf-8') as f:
                caption_data = json.load(f)
            caption = caption_data["SceneSummary"]  # loaded but unused in this first-frame-only script

            if not os.path.exists(scene_cam_path):
                print(f"Pose not found: {scene_cam_path}")
                continue

            camera_poses = np.load(scene_cam_path)
            cam_data_len = camera_poses.shape[0]

            if not os.path.exists(video_path):
                print(f"Video not found: {video_path}")
                continue

            # Output directories are keyed by the clip id prefix only.
            video_name = video_name[:-4].split('_')[0]
            start_frame = 0
            end_frame = num_frames

            cam_interval = end_frame // (cam_data_len - 1)  # computed but unused below

            # Frame indices at which the poses were sampled (also unused below).
            cam_frames = np.linspace(start_frame, end_frame, cam_data_len, endpoint=True)
            cam_frames = np.round(cam_frames).astype(int)
            cam_frames = cam_frames.tolist()

            # Chunk start frames: 0, 300, 600, ...
            sampled_range = range(start_frame, end_frame, chunk_size)
            sampled_frames = list(sampled_range)

            print(f"Encoding scene {video_name}...")
            chunk_count_in_one_video = 0

            for sampled_chunk_start in sampled_frames:
                # Skip trailing chunks with fewer than 100 frames remaining.
                if num_frames - sampled_chunk_start < 100:
                    continue

                sampled_chunk_end = sampled_chunk_start + chunk_size
                start_str = f"{sampled_chunk_start:07d}"
                end_str = f"{sampled_chunk_end:07d}"

                chunk_name = f"{video_name}_{start_str}_{end_str}"
                save_chunk_dir = os.path.join(output_dir, chunk_name)
                os.makedirs(save_chunk_dir, exist_ok=True)

                print(f"Encoding chunk {chunk_name}...")

                first_latent_path = os.path.join(save_chunk_dir, "first_latent.pth")

                if os.path.exists(first_latent_path):
                    print(f"First latent for chunk {chunk_name} already exists, skipping...")
                    continue

                # Load only the single frame this chunk needs.
                first_frame_idx = sampled_chunk_start
                print(f"first_frame:{first_frame_idx}")
                first_frame = encoder.load_single_frame(video_path, first_frame_idx)

                if first_frame is None:
                    print(f"Failed to load frame {first_frame_idx} from: {video_path}")
                    continue

                first_frame = first_frame.to("cuda", dtype=torch.bfloat16)

                # Repeat the frame 4 times along the time axis before encoding
                # (presumably to match the VAE's temporal stride -- TODO confirm).
                repeated_first_frame = first_frame.repeat(1, 1, 4, 1, 1)
                print(f"Repeated first frame shape: {repeated_first_frame.shape}")

                with torch.no_grad():
                    first_latents = encoder.pipe.encode_video(repeated_first_frame, **encoder.tiler_kwargs)[0]

                # Persist on CPU so the file is device-independent.
                first_latent_data = {
                    "latents": first_latents.cpu(),
                }
                torch.save(first_latent_data, first_latent_path)
                print(f"Saved first latent: {first_latent_path}")

                processed_chunk_count += 1
                chunk_count_in_one_video += 1

            processed_count += 1
            print("Encoded scene number:", processed_count)
            print("Encoded chunk number:", processed_chunk_count)

    print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 272 |
+
|
| 273 |
+
if __name__ == "__main__":
    # CLI entry point: encode first-frame latents for every scene under --scenes_path.
    cli = argparse.ArgumentParser()
    option_defaults = {
        "scenes_path": "/share_zhuyixuan05/public_datasets/SpatialVID-HQ/SpatialVid/HQ/videos/",
        "text_encoder_path": "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "vae_path": "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
        "output_dir": "/share_zhuyixuan05/zhuyixuan05/spatialvid",
    }
    for flag, default_value in option_defaults.items():
        cli.add_argument(f"--{flag}", type=str, default=default_value)
    opts = cli.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path, opts.vae_path, opts.output_dir)
|
scripts/hud_logo.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageDraw, ImageFont
import os

# Generates the small WASD key-overlay assets used by the HUD:
# a translucent background plate plus idle/active images for each key.
os.makedirs("wasd_ui", exist_ok=True)

# UI sizes (small)
key_size = (48, 48)
corner = 10
bg_padding = 6
try:
    font = ImageFont.truetype("arial.ttf", 28)  # replace with a locally available font if needed
except OSError:
    # arial.ttf is not installed on many systems (e.g. Linux); fall back to
    # Pillow's built-in bitmap font instead of crashing.
    font = ImageFont.load_default()

def rounded_rect(im, bbox, radius, fill):
    """Draw a filled rounded rectangle (RGBA-aware) onto *im*."""
    draw = ImageDraw.Draw(im, "RGBA")
    draw.rounded_rectangle(bbox, radius=radius, fill=fill)

# background plate: 3 keys wide, 2 rows tall, with padding all around
bg_width = key_size[0] * 3 + bg_padding * 4
bg_height = key_size[1] * 2 + bg_padding * 4
ui_bg = Image.new("RGBA", (bg_width, bg_height), (0, 0, 0, 0))
rounded_rect(ui_bg, (0, 0, bg_width, bg_height), corner, (0, 0, 0, 140))
ui_bg.save("wasd_ui/ui_background.png")

keys = ["W", "A", "S", "D"]

def draw_key(char, active):
    """Render one key-cap image; *active* keys are brighter with darker text."""
    im = Image.new("RGBA", key_size, (0, 0, 0, 0))
    rounded_rect(im, (0, 0, key_size[0], key_size[1]), corner,
                 (255, 255, 255, 230) if active else (200, 200, 200, 180))
    draw = ImageDraw.Draw(im)
    color = (0, 0, 0) if active else (50, 50, 50)
    # BUGFIX: ImageDraw.textsize() was deprecated in Pillow 9.2 and removed in
    # Pillow 10 -- measure the glyph with textbbox() instead, and offset by the
    # bbox origin so the character is truly centered.
    left, top, right, bottom = draw.textbbox((0, 0), char, font=font)
    w, h = right - left, bottom - top
    draw.text(((key_size[0] - w) // 2 - left, (key_size[1] - h) // 2 - top),
              char, font=font, fill=color)
    return im

for k in keys:
    draw_key(k, False).save(f"wasd_ui/key_{k}_idle.png")
    draw_key(k, True).save(f"wasd_ui/key_{k}_active.png")

print("✅ WASD UI assets generated in ./wasd_ui/")
|
scripts/infer_demo.py
ADDED
|
@@ -0,0 +1,1458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 5 |
+
sys.path.append(ROOT_DIR)
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import imageio
|
| 12 |
+
import json
|
| 13 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 14 |
+
import argparse
|
| 15 |
+
from torchvision.transforms import v2
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
import random
|
| 18 |
+
import copy
|
| 19 |
+
from datetime import datetime
|
| 20 |
+
|
| 21 |
+
def compute_relative_pose_matrix(pose1, pose2):
    """Compute the relative pose between two adjacent frames as a 3x4 matrix [R_rel | t_rel].

    Args:
        pose1: camera pose of frame i, shape (7,) array: [tx, ty, tz, qx, qy, qz, qw]
        pose2: camera pose of frame i+1, same layout.

    Returns:
        relative_matrix: 3x4 numpy array; the first 3 columns are the relative
        rotation R_rel, the last column is the relative translation t_rel.
    """
    # Fix: `R` was referenced but never imported anywhere in this module,
    # which made every call raise NameError. Import scipy's Rotation locally.
    from scipy.spatial.transform import Rotation as R

    # Split each pose into translation vector and quaternion.
    t1 = pose1[:3]  # translation of frame i
    q1 = pose1[3:]  # quaternion of frame i  [qx, qy, qz, qw]
    t2 = pose2[:3]  # translation of frame i+1
    q2 = pose2[3:]  # quaternion of frame i+1

    # 1. Relative rotation: R_rel = R2 * R1^{-1}
    rot1 = R.from_quat(q1)
    rot2 = R.from_quat(q2)
    rot_rel = rot2 * rot1.inv()
    R_rel = rot_rel.as_matrix()  # 3x3

    # 2. Relative translation expressed in frame i: t_rel = R1^T @ (t2 - t1)
    R1_T = rot1.as_matrix().T  # transpose == inverse for a rotation matrix
    t_rel = R1_T @ (t2 - t1)

    # 3. Assemble [R_rel | t_rel]
    relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])

    return relative_matrix
| 52 |
+
|
| 53 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a window of pre-encoded video latents from a .pth file.

    Returns (condition_latents, encoded_data) where condition_latents is the
    [C, num_frames, H, W] slice starting at start_frame and encoded_data is the
    full dict loaded from disk.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    end_frame = start_frame + num_frames
    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {end_frame}")

    available = full_latents.shape[1]
    if end_frame > available:
        raise ValueError(f"Not enough frames: requested {end_frame}, available {available}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return camera B's pose relative to camera A: pose_b @ inv(pose_a)."""
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Coerce numpy inputs to float tensors before inverting.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # Numpy path: coerce list-like inputs to float32 arrays.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def replace_dit_model_in_manager():
    """Patch the registered model loader configs so that 'wan_video_dit' loads the MoE variant."""
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        # Only rewrite entries that actually register the DiT model.
        if 'wan_video_dit' not in model_names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(model_names, model_classes):
            patched_names.append(name)
            if name == 'wan_video_dit':
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, model_resource)
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to the DiT model (idempotent)."""
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Infer the transformer inner dim from the first block's Q projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Projects 16-channel clean latents at 1x / 2x / 4x spatio-temporal compression."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            layers = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in layers:
                raise ValueError(f"Unsupported scale: {scale}")
            layer = layers[scale]
            # Match the convolution's dtype before projecting.
            return layer(x.to(layer.weight.dtype))

    # Cast the new embedder to the model's parameter dtype.
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim).to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def add_moe_components(dit_model, moe_config):
    """Attach MoE components (modality processors, global router, per-block MoE) to the model."""
    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

    # Record the config once; top_k defaults to 1 at the model level.
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")
        dit_model.top_k = moe_config.get("top_k", 1)

    # Infer the transformer inner dim from the first block's Q projection.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)

    # One input projection per modality, each mapping into the shared unified space.
    dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
    dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
    dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim)  # OpenX: 13-dim input like sekai, but routed independently
    dit_model.global_router = nn.Linear(unified_dim, num_experts)

    for i, block in enumerate(dit_model.blocks):
        # MoE network: consumes unified_dim, emits the block's hidden dim.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=num_experts,
            top_k=moe_config.get("top_k", 2),
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _sekai_yaw_forward_pose(yaw_per_frame, forward_speed):
    """Build a 4x4 pose: rotation about the Y axis by `yaw_per_frame` plus forward motion along -Z."""
    cos_yaw = np.cos(yaw_per_frame)
    sin_yaw = np.sin(yaw_per_frame)
    pose = np.eye(4, dtype=np.float32)
    pose[0, 0] = cos_yaw
    pose[0, 2] = sin_yaw
    pose[2, 0] = -sin_yaw
    pose[2, 2] = cos_yaw
    pose[2, 3] = -forward_speed
    return pose


def generate_sekai_camera_embeddings_sliding(
        cam_data,
        start_frame,
        initial_condition_frames,
        new_frames,
        total_generated,
        use_real_poses=True,
        direction="left"):
    """Build per-frame camera embeddings for the Sekai dataset - sliding-window version.

    Args:
        cam_data: dict with key 'extrinsic' -> (N, 4, 4) numpy array of extrinsics, or None.
        start_frame: index of the first frame of the current generation window.
        initial_condition_frames: number of initial condition frames.
        new_frames: number of new frames generated in this step.
        total_generated: total frames generated so far (kept for interface compatibility; unused).
        use_real_poses: whether to use the recorded Sekai camera poses.
        direction: synthetic motion pattern used when real poses are unavailable.

    Returns:
        torch.bfloat16 tensor of shape (M, 3*4 + 1): flattened relative pose + condition mask.
    """
    time_compression_ratio = 4

    # FramePack index layout: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Always generate a sequence long enough for both the window and FramePack
    # (identical formula was previously duplicated in both branches).
    max_needed_frames = max(
        start_frame + initial_condition_frames + new_frames,
        framepack_needed_frames,
        30
    )

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + initial_condition_frames + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Map the compressed-frame index back into the raw frame sequence.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                relative_pose = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        # Flatten each 3x4 pose row-major (equivalent to rearrange 'b c d -> b (c d)').
        pose_embedding = pose_embedding.reshape(pose_embedding.shape[0], -1)

        # Frames [start_frame, start_frame + initial_condition_frames) are condition frames.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    # ---- Synthetic camera motion ----
    print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")

    CONDITION_FRAMES = initial_condition_frames
    STAGE_1 = new_frames // 2
    STAGE_2 = new_frames - STAGE_1

    # Per-direction schedule: banner, stage-1 (yaw, forward), stage-2 (yaw, forward).
    # This replaces six near-identical copy-pasted branches that differed only
    # in these constants.
    schedules = {
        "left": ("--------------- LEFT TURNING MODE ---------------", (0.03, 0.00), (0.03, 0.00)),
        "right": ("--------------- RIGHT TURNING MODE ---------------", (-0.03, 0.00), (-0.03, 0.00)),
        "forward_left": ("--------------- FORWARD LEFT MODE ---------------", (0.03, 0.03), (0.03, 0.03)),
        "forward_right": ("--------------- FORWARD RIGHT MODE ---------------", (-0.03, 0.03), (-0.03, 0.03)),
        "s_curve": ("--------------- S CURVE MODE ---------------", (0.03, 0.03), (-0.03, 0.03)),
        "left_right": ("--------------- LEFT RIGHT MODE ---------------", (0.03, 0.00), (-0.03, 0.00)),
    }
    if direction not in schedules:
        raise ValueError(f"未定义的相机运动方向: {direction}")

    banner, stage1_params, stage2_params = schedules[direction]
    print(banner)

    relative_poses = []
    for i in range(max_needed_frames):
        if i < CONDITION_FRAMES:
            # Condition frames: zero motion.
            pose = np.eye(4, dtype=np.float32)
        elif i < CONDITION_FRAMES + STAGE_1 + STAGE_2:
            yaw, fwd = stage1_params if i < CONDITION_FRAMES + STAGE_1 else stage2_params
            pose = _sekai_yaw_forward_pose(yaw, fwd)

            # s_curve only: brief lateral drift (-0.01) in the first third of stage 2
            # to keep turning inertia; elsewhere the x-offset stays 0.
            if (direction == "s_curve"
                    and i >= CONDITION_FRAMES + STAGE_1
                    and i < CONDITION_FRAMES + STAGE_1 + STAGE_2 // 3):
                pose[0, 3] = -0.01
        else:
            # Past condition + target frames: stay still.
            pose = np.eye(4, dtype=np.float32)

        relative_poses.append(torch.as_tensor(pose[:3, :]))

    pose_embedding = torch.stack(relative_poses, dim=0)
    pose_embedding = pose_embedding.reshape(pose_embedding.shape[0], -1)

    # NOTE: unlike the real-pose branch, the synthetic mask extends one frame
    # further (initial_condition_frames + 1) — preserved from the original code.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + initial_condition_frames + 1, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def generate_openx_camera_embeddings_sliding(
        encoded_data, start_frame, initial_condition_frames, new_frames, use_real_poses):
    """Build per-frame camera embeddings for the OpenX dataset - sliding-window version.

    Args:
        encoded_data: dict that may contain ['cam_emb']['extrinsic'] -> (N, 4, 4) extrinsics.
        start_frame: index of the first frame of the current generation window.
        initial_condition_frames: number of condition frames.
        new_frames: number of new frames generated in this step.
        use_real_poses: whether to use the recorded OpenX camera poses.

    Returns:
        torch.bfloat16 tensor of shape (M, 3*4 + 1): flattened relative pose + condition mask.
    """
    time_compression_ratio = 4

    # FramePack index layout: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Always generate a sequence long enough for both the window and FramePack
    # (identical formula was previously duplicated in both branches).
    max_needed_frames = max(
        start_frame + initial_condition_frames + new_frames,
        framepack_needed_frames,
        30
    )

    has_real = (use_real_poses and encoded_data is not None
                and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb'])

    if has_real:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + initial_condition_frames + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX uses the same 4x temporal stride as sekai, on shorter sequences.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                relative_pose = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
    else:
        print("🔧 使用OpenX合成camera数据")
        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

        # Synthetic robot-manipulation motion: small constant roll/pitch/yaw plus
        # a slow translation. The pose is loop-invariant, so build it once
        # (the original recomputed the identical matrix on every iteration).
        roll_per_frame = 0.02   # slight roll
        pitch_per_frame = 0.01  # slight pitch
        yaw_per_frame = 0.015   # slight yaw
        forward_speed = 0.003   # slow advance

        cos_roll, sin_roll = np.cos(roll_per_frame), np.sin(roll_per_frame)
        cos_pitch, sin_pitch = np.cos(pitch_per_frame), np.sin(pitch_per_frame)
        cos_yaw, sin_yaw = np.cos(yaw_per_frame), np.sin(yaw_per_frame)

        pose = np.eye(4, dtype=np.float32)
        # Composite rotation matrix, ZYX order (yaw about Z, pitch about Y, roll about X).
        pose[0, 0] = cos_yaw * cos_pitch
        pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
        pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
        pose[1, 0] = sin_yaw * cos_pitch
        pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
        pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
        pose[2, 0] = -sin_pitch
        pose[2, 1] = cos_pitch * sin_roll
        pose[2, 2] = cos_pitch * cos_roll
        # Fine-grained translation typical of manipulation clips; Z (depth) dominates.
        pose[0, 3] = forward_speed * 0.5
        pose[1, 3] = forward_speed * 0.3
        pose[2, 3] = -forward_speed

        step = torch.as_tensor(pose[:3, :])
        relative_poses = [step] * max_needed_frames  # torch.stack copies, so sharing is safe

    pose_embedding = torch.stack(relative_poses, dim=0)
    # Flatten each 3x4 pose row-major (equivalent to rearrange 'b c d -> b (c d)').
    pose_embedding = pose_embedding.reshape(pose_embedding.shape[0], -1)

    # Frames [start_frame, start_frame + initial_condition_frames) are condition frames.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    if has_real:
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
    else:
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def generate_nuscenes_camera_embeddings_sliding(
        scene_info, start_frame, initial_condition_frames, new_frames):
    """Build camera embeddings for the NuScenes dataset - sliding-window version,
    kept consistent with train_moe.py.

    Returns a torch.bfloat16 tensor of shape (M, 8): 7-D pose vector
    [tx, ty, tz, qw, qx, qy, qz] plus a condition-frame mask column.
    """
    time_compression_ratio = 4  # kept for parity with the other generators (unused here)

    # FramePack layout: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    total_frames = max(framepack_needed_frames, 30)

    def _with_mask(pose_sequence):
        # Append the [0/1] condition-frame mask column -> [total_frames, 8].
        mask = torch.zeros(total_frames, 1, dtype=torch.float32)
        end = min(start_frame + initial_condition_frames, total_frames)
        mask[start_frame:end] = 1.0
        return torch.cat([pose_sequence, mask], dim=1)

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if not keyframe_poses:
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            camera_embedding = _with_mask(torch.zeros(total_frames, 7, dtype=torch.float32))
            print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)

        # Translations are expressed relative to the first keyframe.
        reference = keyframe_poses[0]
        past_end_pose = torch.cat([
            torch.zeros(3, dtype=torch.float32),
            torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32),
        ], dim=0)  # identity quaternion, zero offset

        pose_vecs = []
        for idx in range(total_frames):
            if idx >= len(keyframe_poses):
                # Past the recorded trajectory: zero pose.
                pose_vecs.append(past_end_pose)
                continue
            current = keyframe_poses[idx]
            offset = torch.tensor(
                np.array(current['translation']) - np.array(reference['translation']),
                dtype=torch.float32,
            )
            # Rotation copied through as-is (simplified relative rotation).
            quat = torch.tensor(current['rotation'], dtype=torch.float32)
            pose_vecs.append(torch.cat([offset, quat], dim=0))  # [7]

        camera_embedding = _with_mask(torch.stack(pose_vecs, dim=0))
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用NuScenes合成pose数据")

    # Synthetic left-turn trajectory: circular arc, 0.04 rad per frame,
    # 15 m radius (a gentle, car-like turn).
    pose_vecs = []
    for idx in range(total_frames):
        angle = idx * 0.04
        radius = 15.0

        # Position on the arc; motion stays in the horizontal plane (y = 0).
        offset = torch.tensor(
            [radius * np.sin(angle), 0.0, radius * (1 - np.cos(angle))],
            dtype=torch.float32,
        )

        # Heading follows the arc tangent; quaternion for a rotation about Y.
        yaw = angle + np.pi / 2
        quat = torch.tensor(
            [np.cos(yaw / 2), 0.0, 0.0, np.sin(yaw / 2)],
            dtype=torch.float32,
        )

        pose_vecs.append(torch.cat([offset, quat], dim=0))  # [tx,ty,tz,qw,qx,qy,qz]

    camera_embedding = _with_mask(torch.stack(pose_vecs, dim=0))
    print(f"🔧 NuScenes合成左转pose embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
| 749 |
+
|
| 750 |
+
def prepare_framepack_sliding_window_with_camera_moe(
        history_latents,
        target_frames_to_generate,
        camera_embedding_full,
        start_frame,
        modality_type,
        max_history_frames=49):
    """Assemble one FramePack sliding-window step (MoE variant).

    Lays out a fixed 20-slot context (1 start + 16 "4x" + 2 "2x" + 1 "1x")
    followed by the frames to generate, and gathers the matching clean
    latents and per-frame camera-pose rows into that layout.

    Args:
        history_latents: [C, T, H, W] latents generated so far.
        target_frames_to_generate: number of new latent frames this step.
        camera_embedding_full: [N, D] camera pose rows; the last column is
            the condition(1)/target(0) mask and is rebuilt here.
        start_frame: unused in this function; kept for interface stability.
        modality_type: modality tag, passed through to the result dict.
        max_history_frames: unused in this function; kept for interface
            stability (the caller maintains the window itself).

    Returns:
        dict of index tensors, clean-latent tensors, the combined camera
        embedding, the modality type, and current/next history lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + targets.
    segment_sizes = [1, 16, 2, 1, target_frames_to_generate]
    total_len = sum(segment_sizes)
    idx_start, clean_latent_4x_indices, clean_latent_2x_indices, idx_1x, latent_indices = \
        torch.arange(0, total_len).split(segment_sizes, dim=0)
    clean_latent_indices = torch.cat([idx_start, idx_1x], dim=0)

    # Zero-pad the camera sequence when it is shorter than the layout needs.
    cam_len, cam_dim = camera_embedding_full.shape
    if cam_len < total_len:
        print(f"⚠️ camera_embedding长度不足,进行零补齐: 当前长度 {camera_embedding_full.shape[0]}, 需要长度 {total_len}")
        pad = torch.zeros(total_len - cam_len, cam_dim,
                          dtype=camera_embedding_full.dtype,
                          device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Camera rows aligned with the index layout above.
    combined_camera = torch.zeros(
        total_len, camera_embedding_full.shape[1],
        dtype=camera_embedding_full.dtype,
        device=camera_embedding_full.device)

    # Camera poses for the history slots: the most recent (up to 19) frames,
    # right-aligned so slot 18 always corresponds to the newest history frame.
    hist_rows = camera_embedding_full[max(T - 19, 0):T, :].clone()
    combined_camera[19 - hist_rows.shape[0]:19, :] = hist_rows

    # Camera poses for the frames about to be generated.
    tgt_rows = camera_embedding_full[T:T + target_frames_to_generate, :].clone()
    combined_camera[19:19 + tgt_rows.shape[0], :] = tgt_rows

    # Rebuild the mask column: default everything to target (0), then mark
    # the slots backed by real history as condition (1).
    combined_camera[:, -1] = 0.0
    if T > 0:
        available_frames = min(T, 19)
        combined_camera[19 - available_frames:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Gather clean latents into the same right-aligned 19-slot layout.
    clean_latents_combined = torch.zeros(
        C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        available_frames = min(T, 19)
        clean_latents_combined[:, 19 - available_frames:, :, :] = \
            history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # Anchor latent: the very first history frame (zeros when history is empty).
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(
            C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)
    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag for the MoE router
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
def overlay_controls(frame_img, pose_vec, icons):
    """Overlay WASD / arrow control icons on a frame based on camera pose.

    pose_vec: 12 elements (flattened 3x4 camera matrix) + mask.
    Returns the frame image (unchanged when the pose is absent/zero).
    """
    if pose_vec is None or np.all(pose_vec[:12] == 0):
        return frame_img

    # Translation from the flattened 3x4 layout:
    # [r00, r01, r02, tx, r10, r11, r12, ty, r20, r21, r22, tz]
    trans_x = pose_vec[3]
    # trans_y = pose_vec[7]
    trans_z = pose_vec[11]

    # Yaw (about Y): sin(yaw) = r02, cos(yaw) = r00.
    yaw = np.arctan2(pose_vec[2], pose_vec[0])
    # Pitch (about X): sin(pitch) = -r12, cos(pitch) = r22.
    pitch = np.arctan2(-pose_vec[6], pose_vec[10])

    # Activation thresholds for key highlighting.
    TRANS_THRESH = 0.01
    ROT_THRESH = 0.005

    # Translation keys (WASD); convention: -Z is forward, +X is right.
    key_forward = trans_z < -TRANS_THRESH
    key_backward = trans_z > TRANS_THRESH
    key_left = trans_x < -TRANS_THRESH
    key_right = trans_x > TRANS_THRESH

    # Rotation keys (arrows); yaw + = left, - = right; pitch + = down, - = up.
    key_turn_left = yaw > ROT_THRESH
    key_turn_right = yaw < -ROT_THRESH
    key_turn_up = pitch < -ROT_THRESH
    key_turn_down = pitch > ROT_THRESH

    W, H = frame_img.size
    spacing = 60

    def paste_icon(name_active, name_inactive, is_active, x, y):
        # Pick the active/inactive variant and paste using its alpha channel.
        chosen = name_active if is_active else name_inactive
        if chosen in icons:
            icon = icons[chosen]
            frame_img.paste(icon, (int(x), int(y)), icon)

    # WASD cluster (bottom-left corner).
    wasd_x = 100
    base_y = H - 100
    paste_icon('move_forward.png', 'not_move_forward.png', key_forward, wasd_x, base_y - spacing)
    paste_icon('move_left.png', 'not_move_left.png', key_left, wasd_x - spacing, base_y)
    paste_icon('move_backward.png', 'not_move_backward.png', key_backward, wasd_x, base_y)
    paste_icon('move_right.png', 'not_move_right.png', key_right, wasd_x + spacing, base_y)

    # Arrow cluster (bottom-right corner).
    arrow_x = W - 150
    paste_icon('turn_up.png', 'not_turn_up.png', key_turn_up, arrow_x, base_y - spacing)
    paste_icon('turn_left.png', 'not_turn_left.png', key_turn_left, arrow_x - spacing, base_y)
    paste_icon('turn_down.png', 'not_turn_down.png', key_turn_down, arrow_x, base_y)
    paste_icon('turn_right.png', 'not_turn_right.png', key_turn_right, arrow_x + spacing, base_y)

    return frame_img
def inference_moe_framepack_sliding_window(
        condition_pth_path,
        dit_path,
        output_path="../examples/output_videos/output_moe_framepack_sliding.mp4",
        start_frame=0,
        initial_condition_frames=8,
        frames_per_generation=4,
        total_frames_to_generate=32,
        max_history_frames=49,
        device="cuda",
        prompt="A video of a scene shot using a pedestrian's front camera while walking",
        modality_type="sekai",  # "sekai" or "nuscenes"
        use_real_poses=True,
        scene_info_path=None,  # for the NuScenes dataset
        # CFG parameters
        use_camera_cfg=True,
        camera_guidance_scale=2.0,
        text_guidance_scale=1.0,
        # MoE parameters
        moe_num_experts=4,
        moe_top_k=2,
        moe_hidden_dim=None,
        direction="left",
        use_gt_prompt=True,
        add_icons=False
):
    """MoE FramePack sliding-window video generation with multi-modality support.

    Loads a Wan2.1 T2V pipeline, bolts FramePack + MoE components onto the
    DiT, then generates `total_frames_to_generate` latent frames in chunks of
    `frames_per_generation`, conditioning each chunk on a sliding window of
    history latents plus per-frame camera poses, and finally decodes and
    writes an MP4 (optionally with control-icon overlays).

    Camera CFG and Text CFG can be enabled independently; when both are on,
    text CFG is applied on top of the camera-CFG result.

    NOTE(review): frame counts here are in latent (time-compressed) frames;
    the final log assumes a 4x temporal compression ratio — confirm against
    the VAE configuration.
    """
    # Create the output directory
    dir_path = os.path.dirname(output_path)
    os.makedirs(dir_path, exist_ok=True)

    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialization (swaps in the custom DiT class before loading)
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Add the legacy per-block camera encoder (kept for checkpoint
    #    compatibility); weights are zero/identity-initialized so it is a
    #    no-op until trained weights are loaded.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Add FramePack components
    add_framepack_components(pipe.dit)

    # 4. Add MoE components (per-modality input dims: pose dims + 1 mask dim)
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,  # Sekai: 12-d pose + 1-d mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-d pose + 1-d mask
        "openx_input_dim": 13  # OpenX: 12-d pose + 1-d mask (same layout as sekai)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load the trained weights
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)  # strict=False to tolerate the newly added MoE components
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    # Set the number of denoising steps
    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition frames
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop spatially to the model's working latent resolution
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the prompt - with CFG support
    if use_gt_prompt and 'prompt_emb' in encoded_data:
        print("✅ 使用预编码的GT prompt embedding")
        prompt_emb_pos = encoded_data['prompt_emb']
        # Move the prompt embedding to the right device and dtype
        if 'context' in prompt_emb_pos:
            prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype)
        if 'context_mask' in prompt_emb_pos:
            prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype)

        # When Text CFG is enabled, encode an empty negative prompt
        if text_guidance_scale > 1.0:
            prompt_emb_neg = pipe.encode_prompt("")
            print(f"使用Text CFG with GT prompt,guidance scale: {text_guidance_scale}")
        else:
            prompt_emb_neg = None
            print("不使用Text CFG")

        # Print the GT prompt text (if the dataset stored it)
        if 'prompt' in encoded_data['prompt_emb']:
            gt_prompt_text = encoded_data['prompt_emb']['prompt']
            print(f"📝 GT Prompt文本: {gt_prompt_text}")
    else:
        # Re-encode using the `prompt` argument
        print(f"🔄 重新编码prompt: {prompt}")
        if text_guidance_scale > 1.0:
            prompt_emb_pos = pipe.encode_prompt(prompt)
            prompt_emb_neg = pipe.encode_prompt("")
            print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
        else:
            prompt_emb_pos = pipe.encode_prompt(prompt)
            prompt_emb_neg = None
            print("不使用Text CFG")

    # 8. Load scene info (NuScenes only)
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera embedding sequence for all frames
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            start_frame,
            initial_condition_frames,
            total_frames_to_generate,
            0,
            use_real_poses=use_real_poses,
            direction=direction
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            start_frame,
            initial_condition_frames,
            total_frames_to_generate
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            start_frame,
            initial_condition_frames,
            total_frames_to_generate,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. All-zero camera embedding as the unconditional input for Camera CFG
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window generation loop
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation - MoE version
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Prepare batched inputs (add a batch dimension)
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # Prepare modality_inputs for the MoE router
        modality_inputs = {modality_type: camera_embedding}

        # Unconditional camera embedding for CFG
        # NOTE(review): camera_embedding.shape[1] is the (batched) frame count,
        # so this slices the zero tensor to the same length — verify intent.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors stay on CPU
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Initialize the latents to be generated from Gaussian noise
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop - with CFG support
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # CFG inference
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera)
                    noise_pred_cond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality inputs
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (no camera)
                    noise_pred_uncond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # MoE unconditional modality inputs
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # When Text CFG is used at the same time
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loess = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply Text CFG on top of the already camera-CFG'd result
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only
                    noise_pred_cond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Standard inference (no CFG)
                    noise_pred, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality inputs
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the new chunk to history
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep the first frame plus the most
        # recent (max_history_frames - 1) frames
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode and save
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path} ...")

    # [-1, 1] -> [0, 255] uint8, frame-major
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    icons = {}
    video_camera_poses = None
    if add_icons:
        # Load icon assets for the control overlay
        icons_dir = os.path.join(ROOT_DIR, 'icons')
        icon_names = ['move_forward.png', 'not_move_forward.png',
                      'move_backward.png', 'not_move_backward.png',
                      'move_left.png', 'not_move_left.png',
                      'move_right.png', 'not_move_right.png',
                      'turn_up.png', 'not_turn_up.png',
                      'turn_down.png', 'not_turn_down.png',
                      'turn_left.png', 'not_turn_left.png',
                      'turn_right.png', 'not_turn_right.png']
        for name in icon_names:
            path = os.path.join(icons_dir, name)
            if os.path.exists(path):
                try:
                    icon = Image.open(path).convert("RGBA")
                    # Resize the icon for overlay
                    icon = icon.resize((50, 50), Image.Resampling.LANCZOS)
                    icons[name] = icon
                except Exception as e:
                    print(f"Error loading icon {name}: {e}")
            else:
                print(f"Warning: Icon {name} not found at {path}")

        # Camera poses matched to decoded video frames: each latent pose row
        # is repeated once per temporally-compressed frame
        time_compression_ratio = 4
        camera_poses = camera_embedding_full.detach().float().cpu().numpy()
        video_camera_poses = [x for x in camera_poses for _ in range(time_compression_ratio)]

    with imageio.get_writer(output_path, fps=20) as writer:
        for i, frame in enumerate(video_np):
            # Convert to PIL for overlay
            img = Image.fromarray(frame)

            if add_icons and video_camera_poses is not None and icons:
                # Video frame i corresponds to camera_embedding_full[start_frame + i]
                pose_idx = start_frame + i
                if pose_idx < len(video_camera_poses):
                    pose_vec = video_camera_poses[pose_idx]
                    img = overlay_controls(img, pose_vec, icons)

            writer.append_data(np.array(img))

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
def main():
    """CLI entry point: parse arguments and launch MoE FramePack sliding-window generation."""
    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    # Basic parameters
    parser.add_argument("--condition_pth", type=str,
                        default="../examples/condition_pth/garden_1.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=1)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=24)
    parser.add_argument("--max_history_frames", type=int, default=100)
    # BUGFIX: this was `default=False` with no action, so any supplied value
    # (even the string "False") was truthy; use a real boolean flag.
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str, required=True,
                        help="path to the pretrained DiT MoE model checkpoint")
    parser.add_argument("--output_path", type=str,
                        default='./examples/output_videos/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str, default=None,
                        help="text prompt for video generation")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--add_icons", action="store_true", default=False,
                        help="在生成的视频上叠加控制图标")

    # Modality-type parameters
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"],
                        default="sekai", help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # CFG parameters
    # BUGFIX: same string-truthiness issue as --use_real_poses; now a flag.
    parser.add_argument("--use_camera_cfg", action="store_true", default=False,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # MoE parameters
    parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
    parser.add_argument("--direction", type=str, default="left", help="生成视频的行进轨迹方向")
    parser.add_argument("--use_gt_prompt", action="store_true", default=False,
                        help="使用数据集中的ground truth prompt embedding")

    args = parser.parse_args()

    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"使用GT Prompt: {args.use_gt_prompt}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
    # BUGFIX: log line was missing the separator ("DiT<path>")
    print(f"DiT: {args.dit_path}")

    # Validate NuScenes-specific arguments
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim,
        direction=args.direction,
        use_gt_prompt=args.use_gt_prompt,
        add_icons=args.add_icons
    )
# Script entry point: run the CLI when executed directly.
if __name__ == "__main__":
    main()
scripts/infer_moe.py
ADDED
|
@@ -0,0 +1,1023 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import imageio
|
| 7 |
+
import json
|
| 8 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 9 |
+
import argparse
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import copy
|
| 13 |
+
from scipy.spatial.transform import Rotation as R
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def compute_relative_pose_matrix(pose1, pose2):
    """Compute the relative pose between two adjacent frames as a 3x4 camera matrix [R_rel | t_rel].

    Args:
        pose1: camera pose of frame i, shape (7,): [tx, ty, tz, qx, qy, qz, qw]
        pose2: camera pose of frame i+1, same layout.

    Returns:
        A (3, 4) numpy array: the first three columns are the relative
        rotation R_rel, the last column is the relative translation t_rel.
    """
    # Split each pose into translation and quaternion parts.
    trans_prev, quat_prev = pose1[:3], pose1[3:]
    trans_next, quat_next = pose2[:3], pose2[3:]

    # 1. Relative rotation: R_rel = R_next * R_prev^-1.
    # (scipy's from_quat expects [x, y, z, w], matching the layout above.)
    rot_prev = R.from_quat(quat_prev)
    rot_next = R.from_quat(quat_next)
    rel_rotation = (rot_next * rot_prev.inv()).as_matrix()

    # 2. Relative translation in the previous frame: t_rel = R_prev^T (t_next - t_prev)
    # (transpose of a rotation matrix equals its inverse).
    rel_translation = rot_prev.as_matrix().T @ (trans_next - trans_prev)

    # 3. Assemble [R_rel | t_rel] as a single 3x4 matrix.
    return np.hstack([rel_rotation, rel_translation.reshape(3, 1)])
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Relative rotation quaternion (NuScenes-specific, [w, x, y, z] order).

    Computes q_ref^-1 ⊗ q_current via the Hamilton product; assumes unit
    quaternions (so the conjugate equals the inverse).
    """
    q_cur = torch.tensor(current_rotation, dtype=torch.float32)
    q_ref = torch.tensor(reference_rotation, dtype=torch.float32)

    # Conjugate of the reference quaternion (its inverse for unit quats).
    w1 = q_ref[0]
    x1, y1, z1 = -q_ref[1], -q_ref[2], -q_ref[3]
    w2, x2, y2, z2 = q_cur

    # Hamilton product q_ref_inv ⊗ q_cur, component by component.
    return torch.tensor([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice a frame window.

    Args:
        pth_path: path to a torch checkpoint holding a 'latents' tensor [C, T, H, W].
        start_frame: index of the first latent frame to extract.
        num_frames: number of latent frames to extract.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice and
        the full loaded dictionary.

    Raises:
        ValueError: when the requested window runs past the available frames.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A: pose_b @ pose_a^-1.

    Both inputs are 4x4 extrinsic matrices. The numeric backend is chosen
    with use_torch; inputs are converted to that backend when necessary.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float tensors before inverting.
        pose_a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        pose_b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # numpy path: coerce anything array-like to float32 ndarrays.
    pose_a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    pose_b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def replace_dit_model_in_manager():
    """Swap the registered DiT model class for its MoE variant.

    Walks diffsynth's global model_loader_configs and, wherever a config
    lists 'wan_video_dit', substitutes WanModelMoe as the model class while
    keeping every other entry of the config tuple intact.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        # Only configs that mention the base DiT need patching.
        if 'wan_video_dit' not in model_names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(model_names, model_classes):
            patched_names.append(name)
            if name == 'wan_video_dit':
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_classes.append(cls)

        # Write the patched tuple back into the global registry in place.
        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, model_resource)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to a DiT model.

    Adds a `clean_x_embedder` module (1x/2x/4x Conv3d projections from 16
    latent channels to the transformer width) unless the model already has
    one; the new module is cast to the model's parameter dtype.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return  # already patched

    # Infer the transformer hidden width from the first attention projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            # One projection per FramePack compression level.
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Pick the projection matching the requested scale, casting the
            # input to the layer's dtype first.
            if scale == "1x":
                layer = self.proj
            elif scale == "2x":
                layer = self.proj_2x
            elif scale == "4x":
                layer = self.proj_4x
            else:
                raise ValueError(f"Unsupported scale: {scale}")
            return layer(x.to(layer.weight.dtype))

    embedder = CleanXEmbedder(inner_dim)
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def add_moe_components(dit_model, moe_config):
    """🔧 添加MoE相关组件 - 修正版本

    Attaches the multi-modal MoE machinery to an existing DiT model in
    place: per-modality input processors, a global expert router, and one
    MultiModalMoE module per transformer block.

    Args:
        dit_model: the DiT model to augment.
        moe_config: dict with optional keys 'unified_dim' (default 25),
            'num_experts' (default 4) and 'top_k' (default 2).
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")

    # Read the MoE hyper-parameters exactly once so every component sees the
    # same values. Fix: dit_model.top_k previously defaulted to 1 while each
    # block's MultiModalMoE defaulted top_k to 2, silently disagreeing when
    # the key was absent from moe_config; both now default to 2 (the script's
    # documented default).
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)
    top_k = moe_config.get("top_k", 2)
    dit_model.top_k = top_k

    # Transformer hidden width, inferred from the first attention projection.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
    # Per-modality processors project raw pose inputs to the unified dim.
    dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
    dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
    dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim)  # OpenX使用13维输入,类似sekai但独立处理
    dit_model.global_router = nn.Linear(unified_dim, num_experts)

    for i, block in enumerate(dit_model.blocks):
        # MoE网络 - consumes unified_dim features, emits block-width features.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=num_experts,
            top_k=top_k
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {num_experts})")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """为Sekai数据集生成camera embeddings - 滑动窗口版本

    Produces one row per latent frame: a flattened 3x4 relative camera pose
    (12 values) plus a condition-mask bit, i.e. 13 columns. Returns a
    bfloat16 tensor of shape [max_needed_frames, 13].
    """
    time_compression_ratio = 4

    # Camera frames required by the FramePack index layout:
    # start(1) + 4x(16) + 2x(2) + 1x(1) + newly generated frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    use_real = use_real_poses and cam_data is not None and 'extrinsic' in cam_data

    # Generate a sequence long enough for the sliding window, the FramePack
    # layout, and a floor of 30 frames (same formula as both original branches).
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    relative_poses = []
    if use_real:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        for i in range(max_needed_frames):
            # Map the latent-frame index back onto the raw frame sequence.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
    else:
        print("🔧 使用Sekai合成camera数据")
        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")

        for _ in range(max_needed_frames):
            # Constant left-turn motion pattern.
            yaw_per_frame = -0.1   # per-frame yaw
            forward_speed = 0.005  # per-frame forward travel

            pose = np.eye(4, dtype=np.float32)

            # Rotation about the Y axis.
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)
            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Advance along local -Z (forward) ...
            pose[2, 3] = -forward_speed
            # ... with a slight centripetal drift approximating a circle.
            radius_drift = 0.002
            pose[0, 3] = radius_drift

            relative_poses.append(torch.as_tensor(pose[:3, :]))

    # Flatten each 3x4 pose into 12 values (equivalent to 'b c d -> b (c d)').
    pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

    # Mask column: 1 marks condition frames, 0 marks frames to be generated.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    source = "真实" if use_real else "合成"
    print(f"🔧 Sekai{source}camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 294 |
+
|
| 295 |
+
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """为OpenX数据集生成camera embeddings - 滑动窗口版本

    Same layout as the Sekai variant: 12 flattened pose values + 1 mask bit
    per latent frame. Returns a bfloat16 tensor [max_needed_frames, 13].
    """
    time_compression_ratio = 4

    # Camera frames required by the FramePack index layout.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    use_real = (
        use_real_poses
        and encoded_data is not None
        and 'cam_emb' in encoded_data
        and 'extrinsic' in encoded_data['cam_emb']
    )

    # Long enough for the window, the FramePack layout, and a floor of 30
    # (identical formula in both original branches).
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    relative_poses = []
    if use_real:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        for i in range(max_needed_frames):
            # OpenX samples every 4th raw frame, like sekai but shorter clips.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
    else:
        print("🔧 使用OpenX合成camera数据")
        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

        # Small per-frame angles emulate fine-grained robot-arm motion; they
        # are constant, so the trig terms can be computed once.
        roll_per_frame = 0.02
        pitch_per_frame = 0.01
        yaw_per_frame = 0.015
        forward_speed = 0.003

        cos_roll, sin_roll = np.cos(roll_per_frame), np.sin(roll_per_frame)
        cos_pitch, sin_pitch = np.cos(pitch_per_frame), np.sin(pitch_per_frame)
        cos_yaw, sin_yaw = np.cos(yaw_per_frame), np.sin(yaw_per_frame)

        for _ in range(max_needed_frames):
            pose = np.eye(4, dtype=np.float32)

            # Composite rotation matrix (ZYX order).
            pose[0, 0] = cos_yaw * cos_pitch
            pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
            pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
            pose[1, 0] = sin_yaw * cos_pitch
            pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
            pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
            pose[2, 0] = -sin_pitch
            pose[2, 1] = cos_pitch * sin_roll
            pose[2, 2] = cos_pitch * cos_roll

            # Gentle translation, dominated by the depth (Z) component.
            pose[0, 3] = forward_speed * 0.5
            pose[1, 3] = forward_speed * 0.3
            pose[2, 3] = -forward_speed

            relative_poses.append(torch.as_tensor(pose[:3, :]))

    # Flatten each 3x4 pose into 12 values (equivalent to 'b c d -> b (c d)').
    pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

    # Mask column: 1 marks condition frames, 0 marks frames to be generated.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    source = "真实" if use_real else "合成"
    print(f"🔧 OpenX{source}camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """为NuScenes数据集生成camera embeddings - 滑动窗口版本 - 修正版,与train_moe.py保持一致

    Each row is a 7-D relative pose (3-D translation + 4-D quaternion in
    [w, x, y, z] order) plus a condition-mask bit, i.e. 8 columns per latent
    frame. Returns a bfloat16 tensor of shape [max_needed_frames, 8].
    """
    time_compression_ratio = 4
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    max_needed_frames = max(framepack_needed_frames, 30)

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        # Keyframe index per needed frame (+1 because deltas need pairs),
        # clamped to the available pose range.
        keyframe_indices = [
            min((start_frame + i) * time_compression_ratio, len(keyframe_poses) - 1)
            for i in range(max_needed_frames + 1)
        ]

        pose_vecs = []
        for i in range(max_needed_frames):
            pose_prev = keyframe_poses[keyframe_indices[i]]
            pose_next = keyframe_poses[keyframe_indices[i + 1]]
            # Relative translation between consecutive keyframes.
            translation = torch.tensor(
                np.array(pose_next['translation']) - np.array(pose_prev['translation']),
                dtype=torch.float32
            )
            # Relative rotation quaternion.
            relative_rotation = calculate_relative_rotation(
                pose_next['rotation'],
                pose_prev['rotation']
            )
            pose_vecs.append(torch.cat([translation, relative_rotation], dim=0))  # [7D]

        pose_sequence = torch.stack(pose_vecs, dim=0)  # [max_needed_frames, 7]
        label = "真实pose"
    else:
        print("🔧 使用NuScenes合成pose数据")
        # Build an absolute circular trajectory first (+1 frame for deltas).
        abs_translations = []
        abs_rotations = []
        for i in range(max_needed_frames + 1):
            angle = -i * 0.12
            radius = 8.0
            abs_translations.append(np.array(
                [radius * np.sin(angle), 0.0, radius * (1 - np.cos(angle))],
                dtype=np.float32
            ))
            yaw = angle + np.pi / 2
            # Quaternion [w, x, y, z] for a pure yaw rotation.
            abs_rotations.append(np.array(
                [np.cos(yaw / 2), 0.0, 0.0, np.sin(yaw / 2)],
                dtype=np.float32
            ))

        # Per-frame motion relative to the previous frame.
        pose_vecs = []
        for i in range(max_needed_frames):
            translation = torch.tensor(abs_translations[i + 1] - abs_translations[i], dtype=torch.float32)
            q_prev = abs_rotations[i]
            q_next = abs_rotations[i + 1]
            # Hamilton product q_prev^-1 ⊗ q_next (conjugate == inverse for unit quats).
            w1, x1, y1, z1 = q_prev[0], -q_prev[1], -q_prev[2], -q_prev[3]
            w2, x2, y2, z2 = q_next
            relative_rotation = torch.tensor([
                w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
                w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
                w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
                w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
            ], dtype=torch.float32)
            pose_vecs.append(torch.cat([translation, relative_rotation], dim=0))  # [7D]

        pose_sequence = torch.stack(pose_vecs, dim=0)
        label = "合成相对pose"

    # Mask column: 1 marks condition frames, 0 marks frames to be generated.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0
    camera_embedding = torch.cat([pose_sequence, mask], dim=1)
    print(f"🔧 NuScenes{label} embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 496 |
+
|
| 497 |
+
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """FramePack滑动窗口机制 - MoE版本

    Splits the fixed FramePack index layout (start 1 / 4x 16 / 2x 2 / 1x 1 /
    target frames), aligns the camera embedding with it, rebuilds the
    condition mask from the actual history length, and packs the history
    latents into the multi-scale clean-latent buffers.

    Args:
        history_latents: [C, T, H, W] latents accumulated so far (T may be 0).
        target_frames_to_generate: number of new latent frames this step.
        camera_embedding_full: [F, D] camera sequence; last column is the mask.
        start_frame: window start offset (kept for interface compatibility).
        modality_type: dataset modality tag, passed through to the caller.
        max_history_frames: unused here; kept for interface compatibility.

    Returns:
        Dict of latents, index tensors, the masked camera slice, the modality
        tag, and the current/next history lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed index structure; it also determines the camera length needed.
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    section_sizes = [1, 16, 2, 1, target_frames_to_generate]
    sections = torch.arange(0, total_indices_length).split(section_sizes, dim=0)
    start_idx, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = sections
    clean_latent_indices = torch.cat([start_idx, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence if it is shorter than the index layout.
    deficit = total_indices_length - camera_embedding_full.shape[0]
    if deficit > 0:
        pad = torch.zeros(
            deficit, camera_embedding_full.shape[1],
            dtype=camera_embedding_full.dtype, device=camera_embedding_full.device
        )
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Slice out the window and rebuild the mask: default everything to target (0).
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()
    combined_camera[:, -1] = 0.0

    # Condition mask over the 19 clean slots, right-aligned by actual history.
    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        combined_camera[start_pos:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Pack the most recent history frames, right-aligned, into 19 clean slots.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    # Carve the combined buffer into FramePack's three compression levels.
    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # Anchor latent: the very first history frame (zeros when no history yet).
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag for downstream routing
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai", "nuscenes" or "openx"
    use_real_poses=True,
    scene_info_path=None,  # only used for the NuScenes modality
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None
):
    """MoE FramePack sliding-window video generation with multi-modal camera control.

    Autoregressively extends an initial latent clip: each step conditions the DiT
    on a FramePack-compressed history window plus per-frame camera embeddings,
    denoises ``frames_per_generation`` new latent frames, and appends them to the
    history (keeping the first frame as an anchor once the window is full).
    Optionally applies camera CFG and/or text CFG during denoising.

    Args:
        condition_pth_path: .pth file with pre-encoded latents (and camera data).
        dit_path: checkpoint holding the trained DiT (FramePack + MoE) weights.
        output_path: destination .mp4 path.
        start_frame: first latent frame of the condition clip to use.
        initial_condition_frames: latent frames used as the seed history.
        frames_per_generation: latent frames generated per sliding-window step.
        total_frames_to_generate: total latent frames to generate.
        max_history_frames: cap on the history window length.
        device: torch device string.
        prompt: positive text prompt.
        modality_type: "sekai", "nuscenes" or "openx"; selects the pose pipeline.
        use_real_poses: use recorded poses instead of synthetic ones.
        scene_info_path: NuScenes scene-info JSON (nuscenes modality only).
        use_camera_cfg / camera_guidance_scale: camera classifier-free guidance.
        text_guidance_scale: text CFG scale (>1.0 enables a negative-prompt pass).
        moe_num_experts / moe_top_k / moe_hidden_dim: MoE configuration.

    Raises:
        ValueError: if ``modality_type`` is not one of the supported modalities.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialization: register the MoE DiT class, then load base weights.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Legacy per-block camera encoder (kept for checkpoint compatibility).
    #    Zero-initialized encoder + identity projector so the untrained path is a no-op.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. FramePack components (multi-scale clean-latent embedder).
    add_framepack_components(pipe.dit)

    # 4. MoE components.
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,    # Sekai: 12-dim pose + 1-dim mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-dim pose + 1-dim mask
        "openx_input_dim": 13     # OpenX: 12-dim pose + 1-dim mask (like sekai)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load the trained checkpoint (strict=False tolerates newly added MoE keys).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop spatially to the model resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start + target_height, w_start:w_start + target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the prompt; add an empty negative prompt when text CFG is on.
    prompt_emb_pos = pipe.encode_prompt(prompt)
    if text_guidance_scale > 1.0:
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 8. Optional NuScenes scene metadata.
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera-embedding sequence for the chosen modality.
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None), 0, max_history_frames, 0, 0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info, 0, max_history_frames, 0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data, 0, max_history_frames, 0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. Camera CFG uses an all-zero camera embedding as the unconditional branch.
    # Fix: it is built per-window with zeros_like(camera_embedding) below, instead of
    # slicing a pre-allocated buffer that could come up short when the window is
    # longer than the pre-generated sequence.
    if use_camera_cfg:
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation (MoE variant).
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        modality_inputs = {modality_type: camera_embedding}

        if use_camera_cfg:
            camera_embedding_uncond_batch = torch.zeros_like(camera_embedding)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors are expected on CPU by the DiT.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Start from pure noise for the frames to be generated.
        new_latents = torch.randn(1, C, current_generation, H, W, device=device, dtype=model_dtype)
        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        timesteps = pipe.scheduler.timesteps

        def _call_dit(cam_emb, modal_inputs, prompt_emb, timestep_tensor):
            """One DiT forward pass with the shared FramePack kwargs; discards the MoE aux loss."""
            pred, _moe_loss = pipe.dit(
                new_latents,
                timestep=timestep_tensor,
                cam_emb=cam_emb,
                modality_inputs=modal_inputs,
                latent_indices=latent_indices,
                clean_latents=clean_latents,
                clean_latent_indices=clean_latent_indices,
                clean_latents_2x=clean_latents_2x,
                clean_latent_2x_indices=clean_latent_2x_indices,
                clean_latents_4x=clean_latents_4x,
                clean_latent_4x_indices=clean_latent_4x_indices,
                **prompt_emb,
                **extra_input
            )
            return pred

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Camera CFG: conditional pass vs zero-camera pass.
                    noise_pred_cond = _call_dit(
                        camera_embedding, modality_inputs, prompt_emb_pos, timestep_tensor)
                    noise_pred_uncond = _call_dit(
                        camera_embedding_uncond_batch, modality_inputs_uncond,
                        prompt_emb_neg if prompt_emb_neg else prompt_emb_pos, timestep_tensor)
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optional text CFG stacked on top of the camera-guided result.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond = _call_dit(
                            camera_embedding, modality_inputs, prompt_emb_neg, timestep_tensor)
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only (camera kept conditional in both passes).
                    noise_pred_cond = _call_dit(
                        camera_embedding, modality_inputs, prompt_emb_pos, timestep_tensor)
                    noise_pred_uncond = _call_dit(
                        camera_embedding, modality_inputs, prompt_emb_neg, timestep_tensor)
                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Plain inference, no CFG.
                    noise_pred = _call_dit(
                        camera_embedding, modality_inputs, prompt_emb_pos, timestep_tensor)

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the denoised frames to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: always keep the first frame as an anchor.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames - 1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode and save.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    # Fix: cast the initial latents to the generated dtype as well — torch.cat
    # requires matching dtypes and the .pth latents may be stored in float32.
    final_video = torch.cat(
        [initial_latents.to(all_generated.device, dtype=all_generated.dtype), all_generated],
        dim=1
    ).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] float -> [0, 255] uint8, frame-major layout for imageio.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
def main():
    """CLI entry point: parse arguments and run MoE FramePack sliding-window inference."""
    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    # Basic I/O and window-size parameters.
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=24)
    parser.add_argument("--max_history_frames", type=int, default=100)
    # Fix: declare boolean flags as real booleans. Previously these had no
    # `action`/`type`, so any CLI value (even "False") parsed as a truthy string
    # and --use_real_poses could not be turned off from the command line.
    parser.add_argument("--use_real_poses", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step25000_first.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world ")
    parser.add_argument("--device", type=str, default="cuda")

    # Modality selection.
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
                        help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # CFG parameters.
    parser.add_argument("--use_camera_cfg", action="store_true", default=False,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # MoE parameters.
    parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")

    args = parser.parse_args()

    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
    # Fix: the original message f"DiT{...}" lacked a separator.
    print(f"DiT: {args.dit_path}")

    # NuScenes without scene info falls back to synthetic pose data.
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim
    )
|
| 1020 |
+
|
| 1021 |
+
|
| 1022 |
+
# Script entry point: parse CLI arguments and launch generation.
if __name__ == "__main__":
    main()
|
scripts/infer_moe_spatialvid.py
ADDED
|
@@ -0,0 +1,1008 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import imageio
|
| 7 |
+
import json
|
| 8 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 9 |
+
import argparse
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import copy
|
| 13 |
+
from scipy.spatial.transform import Rotation as R
|
| 14 |
+
|
| 15 |
+
def compute_relative_pose_matrix(pose1, pose2):
    """Return the 3x4 relative camera matrix [R_rel | t_rel] between two frames.

    Args:
        pose1: pose of frame i, shape (7,): [tx, ty, tz, qx, qy, qz, qw].
        pose2: pose of frame i+1, same layout.

    Returns:
        np.ndarray of shape (3, 4): relative rotation in the first three
        columns, relative translation in the fourth.
    """
    # Split translation (first three) from quaternion (last four, xyzw order).
    t_prev, q_prev = pose1[:3], pose1[3:]
    t_next, q_next = pose2[:3], pose2[3:]

    rot_prev = R.from_quat(q_prev)
    rot_next = R.from_quat(q_next)

    # Relative rotation: next frame's rotation composed with the previous one's inverse.
    rel_rotation = (rot_next * rot_prev.inv()).as_matrix()

    # Relative translation expressed in the previous frame: R_prev^T @ (t_next - t_prev).
    rel_translation = rot_prev.as_matrix().T @ (t_next - t_prev)

    # Stack into a single 3x4 [R | t] matrix.
    return np.hstack([rel_rotation, rel_translation.reshape(3, 1)])
|
| 46 |
+
|
| 47 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice out a frame window.

    Args:
        pth_path: path to the .pth produced by the encoding script; must contain
            a 'latents' tensor of shape [C, T, H, W].
        start_frame: first latent frame of the window.
        num_frames: window length in latent frames.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice and
        the full loaded dict (for access to camera/pose entries).

    Raises:
        ValueError: if the requested window extends past the stored frames.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    end_frame = start_frame + num_frames
    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {end_frame}")

    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {end_frame}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A as the 4x4 matrix B @ A^{-1}.

    Args:
        pose_a: 4x4 extrinsic matrix of camera A (numpy array or torch tensor).
        pose_b: 4x4 extrinsic matrix of camera B.
        use_torch: when True, compute with torch (inputs are converted to float
            tensors if needed); otherwise compute with numpy float32.

    Returns:
        The 4x4 relative pose matrix, as a tensor or ndarray per ``use_torch``.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Normalize inputs to float tensors, then compose B with A's inverse.
        pose_a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        pose_b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # Numpy branch: normalize to float32 arrays.
    pose_a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    pose_b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def replace_dit_model_in_manager():
    """Patch the global model-loader registry so 'wan_video_dit' maps to the MoE class.

    Mutates ``model_loader_configs`` in place: every config entry that contains
    the 'wan_video_dit' model name gets its corresponding class swapped for
    ``WanModelMoe``; all other names/classes are kept unchanged.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        # Only entries that register the base DiT need patching.
        if 'wan_video_dit' not in model_names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(model_names, model_classes):
            patched_names.append(name)
            if name == 'wan_video_dit':
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, model_resource)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def add_framepack_components(dit_model):
    """Attach the FramePack multi-scale clean-latent embedder to a DiT model.

    Idempotent: does nothing if ``dit_model`` already has a ``clean_x_embedder``.
    The embedder projects 16-channel clean latents into the transformer's inner
    dimension at three compression levels (1x, 2x, 4x) via strided Conv3d layers,
    and is cast to the model's parameter dtype.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Inner dim inferred from the first block's self-attention query projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Projects clean latents at 1x / 2x / 4x temporal-spatial compression."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch on the requested compression scale; cast the input to
            # the layer's dtype so mixed-precision models work.
            layers = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in layers:
                raise ValueError(f"Unsupported scale: {scale}")
            layer = layers[scale]
            return layer(x.to(layer.weight.dtype))

    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim).to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def add_moe_components(dit_model, moe_config):
    """Attach per-block MoE components (modality processor + expert router).

    Stores ``moe_config`` on the model (once), then gives every transformer
    block a Sekai ``ModalityProcessor`` (13-dim raw camera input ->
    ``unified_dim``) and a ``MultiModalMoE`` head (``unified_dim`` in,
    transformer hidden dim out).
    """
    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")

    # Hidden size read off the first block's query projection; MoE settings
    # are loop-invariant, so resolve them once up front.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)
    top_k = moe_config.get("top_k", 2)

    for i, block in enumerate(dit_model.blocks):
        # Sekai modality processor: 13-dim input -> unified_dim.
        block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)

        # MoE network: unified_dim in, transformer block dim out.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=num_experts,
            top_k=top_k
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {num_experts})")
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build per-latent-frame camera embeddings for the Sekai modality.

    Returns a ``[num_frames, 13]`` bfloat16 tensor: a flattened 3x4 relative
    camera pose (12 values) per latent frame plus a trailing mask bit
    (1.0 = condition frame, 0.0 = target frame). With real extrinsics
    available and requested, relative poses are taken between raw frames 4
    apart (one latent step); otherwise a synthetic steady left-turn walking
    motion is emitted.

    NOTE(review): ``total_generated`` is unused here; kept for signature
    compatibility with callers.
    """
    time_compression_ratio = 4

    # Camera rows FramePack consumes: 1 start + 16 (4x) + 2 (2x) + 1 (1x)
    # + the frames generated this step.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    has_real = use_real_poses and cam_data is not None and 'extrinsic' in cam_data

    rel_poses = []
    if has_real:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f"   - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f"   - FramePack需求: {framepack_needed_frames}")
        print(f"   - 最终生成: {max_needed_frames}")

        for i in range(max_needed_frames):
            # Position of this latent frame in the raw frame sequence.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio
            if next_frame_idx < len(cam_extrinsic):
                delta = compute_relative_pose_matrix(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                rel_poses.append(torch.as_tensor(delta[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                rel_poses.append(torch.zeros(3, 4))
    else:
        print("🔧 使用Sekai合成camera数据")
        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")

        # Constant per-frame motion (loop-invariant, built once): turn left
        # around Y while walking forward, with a slight drift toward the
        # circle centre to mimic a circular trajectory.
        yaw_per_frame = 0.05
        forward_speed = 0.005
        radius_drift = 0.002
        cos_yaw = np.cos(yaw_per_frame)
        sin_yaw = np.sin(yaw_per_frame)

        pose = np.eye(4, dtype=np.float32)
        pose[0, 0] = cos_yaw
        pose[0, 2] = sin_yaw
        pose[2, 0] = -sin_yaw
        pose[2, 2] = cos_yaw
        pose[2, 3] = -forward_speed   # local -Z: move forward
        pose[0, 3] = -radius_drift    # local -X: drift left

        for _ in range(max_needed_frames):
            rel_poses.append(torch.as_tensor(pose[:3, :]))

    # [N, 3, 4] -> [N, 12], then append the condition-mask column.
    pose_embedding = rearrange(torch.stack(rel_poses, dim=0), 'b c d -> b (c d)')
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0
    camera_embedding = torch.cat([pose_embedding, mask], dim=1)

    tag = "真实" if has_real else "合成"
    print(f"🔧 Sekai{tag}camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Build per-latent-frame camera embeddings for the OpenX modality.

    Same 13-column layout as the Sekai variant (flattened 3x4 relative pose
    + trailing condition-mask bit), returned as bfloat16. Real extrinsics are
    read from ``encoded_data['cam_emb']['extrinsic']`` when requested;
    otherwise a gentle synthetic robot-arm-style motion is used.
    """
    time_compression_ratio = 4

    # Camera rows FramePack consumes: 1 start + 16 (4x) + 2 (2x) + 1 (1x)
    # + the frames generated this step.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    has_real = (
        use_real_poses
        and encoded_data is not None
        and 'cam_emb' in encoded_data
        and 'extrinsic' in encoded_data['cam_emb']
    )

    rel_poses = []
    if has_real:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f"   - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f"   - FramePack需求: {framepack_needed_frames}")
        print(f"   - 最终生成: {max_needed_frames}")

        for i in range(max_needed_frames):
            # Sample raw frames 4 apart (one latent step), like the sekai path.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio
            if next_frame_idx < len(cam_extrinsic):
                # NOTE(review): the sekai path calls compute_relative_pose_matrix
                # while this one calls compute_relative_pose — confirm both
                # helpers exist and agree on pose conventions.
                delta = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                rel_poses.append(torch.as_tensor(delta[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                rel_poses.append(torch.zeros(3, 4))
    else:
        print("🔧 使用OpenX合成camera数据")
        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

        # Small-amplitude composite motion mimicking fine robot-arm operation.
        # Identical for every frame, so the pose is built once outside the loop.
        roll_per_frame = 0.02
        pitch_per_frame = 0.01
        yaw_per_frame = 0.015
        forward_speed = 0.003

        cos_roll, sin_roll = np.cos(roll_per_frame), np.sin(roll_per_frame)
        cos_pitch, sin_pitch = np.cos(pitch_per_frame), np.sin(pitch_per_frame)
        cos_yaw, sin_yaw = np.cos(yaw_per_frame), np.sin(yaw_per_frame)

        pose = np.eye(4, dtype=np.float32)
        # Simplified composite rotation matrix (ZYX order).
        pose[0, 0] = cos_yaw * cos_pitch
        pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
        pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
        pose[1, 0] = sin_yaw * cos_pitch
        pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
        pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
        pose[2, 0] = -sin_pitch
        pose[2, 1] = cos_pitch * sin_roll
        pose[2, 2] = cos_pitch * cos_roll
        # Translation: mostly along depth, with slight lateral components.
        pose[0, 3] = forward_speed * 0.5
        pose[1, 3] = forward_speed * 0.3
        pose[2, 3] = -forward_speed

        for _ in range(max_needed_frames):
            rel_poses.append(torch.as_tensor(pose[:3, :]))

    # [N, 3, 4] -> [N, 12], then append the condition-mask column.
    pose_embedding = rearrange(torch.stack(rel_poses, dim=0), 'b c d -> b (c d)')
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0
    camera_embedding = torch.cat([pose_embedding, mask], dim=1)

    tag = "真实" if has_real else "合成"
    print(f"🔧 OpenX{tag}camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Build per-latent-frame camera embeddings for the NuScenes modality.

    Each latent frame gets a 7-D pose vector (3-D translation relative to the
    first keyframe + 4-D quaternion) plus a trailing condition-mask bit,
    giving a ``[num_frames, 8]`` bfloat16 tensor. Frames in
    ``[start_frame, start_frame + current_history_length)`` are flagged as
    condition (mask == 1). Kept consistent with train_moe.py.
    """
    # Camera rows FramePack consumes: 1 start + 16 (4x) + 2 (2x) + 1 (1x)
    # + the frames generated this step (minimum 30 rows overall).
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    max_needed_frames = max(framepack_needed_frames, 30)

    def finish(pose_sequence, label):
        # Shared tail: append the condition mask column and cast to bf16.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0
        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes{label}pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if len(keyframe_poses) == 0:
            # Degenerate case: no keyframes at all -> all-zero poses.
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            return finish(torch.zeros(max_needed_frames, 7, dtype=torch.float32), "零")

        # Translations are expressed relative to the first keyframe.
        reference_translation = np.array(keyframe_poses[0]['translation'])
        identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)

        pose_vecs = []
        for idx in range(max_needed_frames):
            if idx < len(keyframe_poses):
                current_pose = keyframe_poses[idx]
                translation = torch.tensor(
                    np.array(current_pose['translation']) - reference_translation,
                    dtype=torch.float32
                )
                # Simplified: raw quaternion, no relative-rotation computation.
                rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)
            else:
                # Past the recorded keyframes: zero translation, identity quat.
                translation = torch.zeros(3, dtype=torch.float32)
                rotation = identity_quat
            pose_vecs.append(torch.cat([translation, rotation], dim=0))  # [7]

        return finish(torch.stack(pose_vecs, dim=0), "真实")

    print("🔧 使用NuScenes合成pose数据")
    # Synthetic motion: steady advance along Z, no rotation.
    pose_vecs = [
        torch.cat([
            torch.tensor([0.0, 0.0, idx * 0.1], dtype=torch.float32),
            torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32),
        ], dim=0)
        for idx in range(max_needed_frames)
    ]
    return finish(torch.stack(pose_vecs, dim=0), "合成")
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """Assemble FramePack conditioning tensors + camera slice for one MoE step.

    ``history_latents`` is ``[C, T, H, W]`` — latents accumulated so far.
    Builds the fixed 1 (start) + 16 (4x) + 2 (2x) + 1 (1x) + target index
    layout, right-aligns up to 19 history frames into the clean-latent slots,
    and rewrites the camera mask column so only slots backed by real history
    count as condition. Returns a dict with latents, index tensors, the
    trimmed camera embedding, the modality type and bookkeeping lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout; its length also fixes how many camera rows we need.
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    sections = torch.arange(0, total_indices_length).split(
        [1, 16, 2, 1, target_frames_to_generate], dim=0)
    start_idx, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = sections
    clean_latent_indices = torch.cat([start_idx, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence when it is shorter than the index layout.
    shortage = total_indices_length - camera_embedding_full.shape[0]
    if shortage > 0:
        padding = torch.zeros(shortage, camera_embedding_full.shape[1],
                              dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)

    # Slice out the rows this step consumes; clone so the mask rewrite below
    # cannot leak back into the full sequence.
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()
    combined_camera[:, -1] = 0.0  # reset: everything is target by default

    # Right-align: the last `available_frames` of the 19 clean slots carry
    # real history, the rest stay zero/target.
    available_frames = min(T, 19)
    start_pos = 19 - available_frames
    if T > 0:
        combined_camera[start_pos:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f"   - 历史帧数: {T}")
    print(f"   - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f"   - 模态类型: {modality_type}")

    # 19-slot clean-latent buffer, zeros where history is missing.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # Anchor frame: the very first history frame (zeros before any history).
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag passed through to the DiT
        'current_length': T,
        'next_length': T + target_frames_to_generate,
    }
def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai", "nuscenes" or "openx" (see dispatch below)
    use_real_poses=True,
    scene_info_path=None,  # JSON scene metadata, only used for the NuScenes modality
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None
):
    """Autoregressively generate video with the MoE FramePack sliding window.

    Loads the Wan2.1 T2V pipeline, patches its DiT with legacy camera
    encoders, FramePack clean-latent embedders and per-block MoE components,
    loads the trained weights from ``dit_path``, then generates
    ``total_frames_to_generate`` latent frames in chunks of
    ``frames_per_generation``. Each chunk is conditioned on a sliding window
    of history latents plus modality-specific camera embeddings, with
    optional camera CFG (uncond = all-zero camera embedding) and text CFG
    (uncond = empty prompt). The decoded video is written to ``output_path``.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialization: swap the registered DiT class for the MoE
    #    variant before the model manager instantiates anything.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Add legacy per-block camera encoders (compatibility path).
    #    Zero-init encoder + identity projector keep this added path inert
    #    until the trained weights below overwrite them.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Attach FramePack clean-latent embedders.
    add_framepack_components(pipe.dit)

    # 4. Attach MoE components (per-block modality processor + expert router).
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,   # Sekai: 12-dim pose + 1-dim mask
        "nuscenes_input_dim": 8, # NuScenes: 7-dim pose + 1-dim mask
        "openx_input_dim": 13    # OpenX: 12-dim pose + 1-dim mask (sekai-like)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load the trained weights. strict=False tolerates the freshly added
    #    MoE/FramePack parameters that may be absent from the checkpoint.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition latents from the pre-encoded .pth file.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop the latents to 60x104 (latent-space size — presumably
    # matching the training resolution; TODO confirm against the encoder).
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the prompt; an empty-prompt negative is encoded only when
    #    text CFG is enabled.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 8. Load scene info (NuScenes modality only).
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera embedding sequence for this modality
    #    (mask column is rewritten per step inside the loop below).
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            0,
            max_history_frames,
            0,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            0,
            max_history_frames,
            0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            0,
            max_history_frames,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. All-zero camera embedding serves as the CFG "unconditional" branch.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation (MoE version): clean latents, index
        # layout, and the per-step camera slice with a rewritten mask.
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Batch dimension for the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # MoE modality inputs keyed by modality name.
        modality_inputs = {modality_type: camera_embedding}

        # Matching unconditional camera input for CFG (same row count).
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors stay on CPU.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Fresh noise for the frames generated this step.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop, with optional camera and/or text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f"  去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # The MoE DiT returns (noise_pred, moe_loss); moe_loss is a
                # training-time auxiliary and is discarded at inference.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera).
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality inputs
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zeroed camera).
                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # MoE unconditional modality inputs
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack text CFG on top of the camera CFG result.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loss = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the camera-CFG-combined prediction.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only (camera input identical on both branches).
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Plain inference (no CFG).
                    noise_pred, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality inputs
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the freshly denoised frames to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: always keep the anchor (first) frame
        # plus the most recent max_history_frames-1 frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode and save: initial condition latents + everything generated.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [C, T, H, W] -> [T, H, W, C], [-1, 1] -> uint8 [0, 255].
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
def main():
    """CLI entry point: parse arguments and launch MoE FramePack sliding-window generation."""

    def _str2bool(value):
        # Bug fix: argparse treats any non-empty string (even "False") as truthy
        # when an option has no type/action, so the original flag could never be
        # turned off from the command line. Parse common spellings explicitly.
        if isinstance(value, bool):
            return value
        return value.strip().lower() in ("1", "true", "yes", "y", "on")

    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    # Basic I/O and sliding-window parameters
    parser.add_argument("--condition_pth", type=str,
                        #default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
                        #default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
                        default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
                        #default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=8)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_spatialvid/step250_moe.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A man enter the room")
    parser.add_argument("--device", type=str, default="cuda")

    # Modality selection
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
                        help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # CFG parameters
    parser.add_argument("--use_camera_cfg", type=_str2bool, default=True,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # MoE parameters
    parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")

    args = parser.parse_args()

    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
        print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")

    # NuScenes modality without a scene-info file falls back to synthetic poses.
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim
    )


if __name__ == "__main__":
    main()
|
scripts/infer_moe_test.py
ADDED
|
@@ -0,0 +1,976 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import imageio
|
| 7 |
+
import json
|
| 8 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 9 |
+
import argparse
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import copy
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice out a window.

    Returns ``(condition_latents, encoded_data)`` where ``condition_latents``
    is ``[C, num_frames, H, W]`` taken from the stored ``'latents'`` tensor
    starting at ``start_frame``, and ``encoded_data`` is the full dict loaded
    from disk. Raises ``ValueError`` when the requested window exceeds the
    stored frame count.
    """
    print(f"Loading encoded video from {pth_path}")

    payload = torch.load(pth_path, weights_only=False, map_location="cpu")
    latents_all = payload['latents']  # [C, T, H, W]

    print(f"Full latents shape: {latents_all.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    window_end = start_frame + num_frames
    available = latents_all.shape[1]
    if window_end > available:
        raise ValueError(f"Not enough frames: requested {window_end}, available {available}")

    window = latents_all[:, start_frame:window_end, :, :]
    print(f"Extracted condition latents shape: {window.shape}")

    return window, payload
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A as ``pose_b @ inv(pose_a)``.

    Both inputs must be 4x4 extrinsic matrices. ``use_torch`` selects the
    torch backend (returning a ``torch.Tensor``) instead of numpy.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float tensors before inverting.
        pose_a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        pose_b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    pose_a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    pose_b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' model class for the MoE variant.

    Mutates ``model_loader_configs`` in place so that subsequent model loads
    instantiate ``WanModelMoe`` instead of the stock DiT class.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        if 'wan_video_dit' not in model_names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(model_names, model_classes):
            patched_names.append(name)
            if name == 'wan_video_dit':
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, model_resource)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to the DiT model.

    Idempotent: does nothing when ``clean_x_embedder`` already exists. The
    embedder patchifies 16-channel latents at 1x/2x/4x temporal-spatial
    scales via Conv3d and is cast to the model's parameter dtype.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Infer the transformer hidden size from the first block's Q projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            conv = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}.get(scale)
            if conv is None:
                raise ValueError(f"Unsupported scale: {scale}")
            # Cast to the conv weight dtype so mixed-precision inputs work.
            return conv(x.to(conv.weight.dtype))

    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def add_moe_components(dit_model, moe_config):
    """🔧 Attach per-block MoE modules and record the MoE config on the model."""
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")

    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

    # Hidden size is read off the first block's Q projection; every block
    # receives the same modality processor + MoE pair.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)

    for i, block in enumerate(dit_model.blocks):
        # Sekai modality processor: 13-D camera rows -> unified_dim.
        block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)

        # # NuScenes modality processor (disabled): 8-D poses -> unified_dim.
        # block.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)

        # MoE network: unified_dim features -> the block's hidden dim.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build per-frame camera embeddings for the Sekai modality (sliding window).

    Each row is a flattened 3x4 relative pose (12 values) plus one condition
    mask bit, producing a ``[num_frames, 13]`` bfloat16 tensor. Real poses
    come from ``cam_data['extrinsic']`` sampled at the 4x temporal compression
    stride; otherwise a synthetic constant left-turn trajectory is used.
    ``total_generated`` is kept for interface compatibility but is unused here.
    """
    time_compression_ratio = 4

    # Camera rows FramePack consumes: start + 16 (4x) + 2 (2x) + 1 (1x) + targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Cover both the sliding-window need and FramePack's fixed layout;
    # 30 is a safety floor.
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    def _condition_mask(length):
        # Mark [start_frame, start_frame + history) as condition frames (1.0).
        mask = torch.zeros(length, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, length)
        mask[start_frame:condition_end] = 1.0
        return mask

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Map compressed-frame index i back onto the raw frame sequence.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        # Row-major flatten of each 3x4 pose (equivalent to the previous
        # einops rearrange 'b c d -> b (c d)').
        pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

        camera_embedding = torch.cat([pose_embedding, _condition_mask(max_needed_frames)], dim=1)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用Sekai合成camera数据")
    print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")

    # Synthetic motion: constant left turn plus slow forward drift. The pose
    # is identical for every frame, so build it once and repeat it (the
    # original rebuilt the same matrix inside the loop).
    yaw_per_frame = 0.05    # left turn per frame (positive = left)
    forward_speed = 0.005   # forward travel per frame

    pose = np.eye(4, dtype=np.float32)
    cos_yaw = np.cos(yaw_per_frame)
    sin_yaw = np.sin(yaw_per_frame)
    # Rotation about the Y axis (left turn).
    pose[0, 0] = cos_yaw
    pose[0, 2] = sin_yaw
    pose[2, 0] = -sin_yaw
    pose[2, 2] = cos_yaw
    # Translate forward along local -Z.
    pose[2, 3] = -forward_speed
    # Slight inward (-X) drift to approximate a circular trajectory.
    radius_drift = 0.002
    pose[0, 3] = -radius_drift

    pose_embedding = torch.as_tensor(pose[:3, :]).reshape(1, 12).repeat(max_needed_frames, 1)

    camera_embedding = torch.cat([pose_embedding, _condition_mask(max_needed_frames)], dim=1)
    print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 244 |
+
|
| 245 |
+
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Build per-frame camera embeddings for the OpenX modality (sliding window).

    Rows are flattened 3x4 relative poses plus a condition-mask bit:
    ``[num_frames, 13]`` in bfloat16. Real poses come from
    ``encoded_data['cam_emb']['extrinsic']``; otherwise a synthetic small
    robot-arm motion is used.
    """
    time_compression_ratio = 4

    # Camera rows FramePack consumes: start + 16 (4x) + 2 (2x) + 1 (1x) + targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    def _condition_mask(length):
        # Frames [start_frame, start_frame + history) act as condition (1.0).
        mask = torch.zeros(length, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, length)
        mask[start_frame:condition_end] = 1.0
        return mask

    has_real = (use_real_poses and encoded_data is not None
                and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb'])
    if has_real:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX uses the same 4-frame stride as Sekai, just shorter clips.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                relative_pose = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

        camera_embedding = torch.cat([pose_embedding, _condition_mask(max_needed_frames)], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用OpenX合成camera数据")
    print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

    # Synthetic robot-manipulation motion: small roll/pitch/yaw and a slow
    # advance. The pose is frame-invariant, so compute it once (the original
    # rebuilt it every loop iteration) and repeat it.
    roll_per_frame = 0.02   # slight roll
    pitch_per_frame = 0.01  # slight pitch
    yaw_per_frame = 0.015   # slight yaw
    forward_speed = 0.003   # slow advance

    cos_roll, sin_roll = np.cos(roll_per_frame), np.sin(roll_per_frame)
    cos_pitch, sin_pitch = np.cos(pitch_per_frame), np.sin(pitch_per_frame)
    cos_yaw, sin_yaw = np.cos(yaw_per_frame), np.sin(yaw_per_frame)

    pose = np.eye(4, dtype=np.float32)
    # Composite ZYX rotation approximating fine manipulator motion.
    pose[0, 0] = cos_yaw * cos_pitch
    pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
    pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
    pose[1, 0] = sin_yaw * cos_pitch
    pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
    pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
    pose[2, 0] = -sin_pitch
    pose[2, 1] = cos_pitch * sin_roll
    pose[2, 2] = cos_pitch * cos_roll

    # Translation: slight X drift, slight Y drift (comment previously garbled
    # by mojibake), dominant motion along -Z (depth).
    pose[0, 3] = forward_speed * 0.5
    pose[1, 3] = forward_speed * 0.3
    pose[2, 3] = -forward_speed

    pose_embedding = torch.as_tensor(pose[:3, :]).reshape(1, 12).repeat(max_needed_frames, 1)

    camera_embedding = torch.cat([pose_embedding, _condition_mask(max_needed_frames)], dim=1)
    print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Build NuScenes camera embeddings for the sliding window.

    Rows are a 7-D pose (3-D translation relative to the first keyframe plus
    a 4-D quaternion) and one condition-mask bit, returned as a
    ``[num_frames, 8]`` bfloat16 tensor. Layout matches train_moe.py.
    """
    time_compression_ratio = 4

    # Camera rows FramePack consumes: start + 16 (4x) + 2 (2x) + 1 (1x) + targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    max_needed_frames = max(framepack_needed_frames, 30)

    def _mask(length):
        # Frames [start_frame, start_frame + history) act as condition (1.0).
        m = torch.zeros(length, 1, dtype=torch.float32)
        end = min(start_frame + current_history_length, length)
        m[start_frame:end] = 1.0
        return m

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if len(keyframe_poses) == 0:
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)
            camera_embedding = torch.cat([pose_sequence, _mask(max_needed_frames)], dim=1)  # [max_needed_frames, 8]
            print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)

        # Translations are measured relative to the first keyframe.
        reference_pose = keyframe_poses[0]

        # Zero translation + identity quaternion, used past the recorded poses.
        identity_row = torch.cat([
            torch.zeros(3, dtype=torch.float32),
            torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
        ], dim=0)

        rows = []
        for idx in range(max_needed_frames):
            if idx < len(keyframe_poses):
                pose = keyframe_poses[idx]
                # Relative displacement from the reference keyframe.
                translation = torch.tensor(
                    np.array(pose['translation']) - np.array(reference_pose['translation']),
                    dtype=torch.float32
                )
                # Absolute rotation quaternion (simplified; not made relative).
                rotation = torch.tensor(pose['rotation'], dtype=torch.float32)
                rows.append(torch.cat([translation, rotation], dim=0))  # [7]
            else:
                rows.append(identity_row)

        pose_sequence = torch.stack(rows, dim=0)  # [max_needed_frames, 7]
        camera_embedding = torch.cat([pose_sequence, _mask(max_needed_frames)], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用NuScenes合成pose数据")

    # Synthetic motion: straight-ahead driving along +Z, no rotation.
    rows = []
    for idx in range(max_needed_frames):
        translation = torch.tensor([0.0, 0.0, idx * 0.1], dtype=torch.float32)
        rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
        rows.append(torch.cat([translation, rotation], dim=0))  # [7]

    pose_sequence = torch.stack(rows, dim=0)
    camera_embedding = torch.cat([pose_sequence, _mask(max_needed_frames)], dim=1)  # [max_needed_frames, 8]
    print(f"🔧 NuScenes合成pose embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 450 |
+
|
| 451 |
+
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """Assemble FramePack conditioning inputs for one sliding-window step (MoE).

    ``history_latents`` is ``[C, T, H, W]``. Returns a dict with the clean
    latents at 1x/2x/4x scales, their index tensors, the camera-embedding
    slice (with the condition-mask column rewritten for the current history),
    the modality tag, and the current/next history lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed FramePack index layout: [start | 16 x 4x | 2 x 2x | 1 x 1x | targets].
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    split_sizes = [1, 16, 2, 1, target_frames_to_generate]
    (clean_latent_indices_start,
     clean_latent_4x_indices,
     clean_latent_2x_indices,
     clean_latent_1x_indices,
     latent_indices) = torch.arange(0, total_indices_length).split(split_sizes, dim=0)
    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence when it is shorter than the index layout.
    deficit = total_indices_length - camera_embedding_full.shape[0]
    if deficit > 0:
        pad = torch.zeros(deficit, camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Slice the window and rewrite the mask column: everything defaults to
    # target (0), then the trailing slots of the 19 clean positions that hold
    # real history become condition (1).
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()
    combined_camera[:, -1] = 0.0

    available_frames = min(T, 19) if T > 0 else 0
    if T > 0:
        combined_camera[19 - available_frames:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Right-align the most recent history into the 19 clean-latent slots.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        clean_latents_combined[:, 19 - available_frames:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag for downstream routing
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai", "nuscenes" or "openx" (dispatched in step 9)
    use_real_poses=True,
    scene_info_path=None,  # JSON scene metadata; only read for the NuScenes modality
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None
):
    """MoE FramePack sliding-window video generation with multi-modality support.

    Loads the Wan2.1 T2V pipeline, augments its DiT with per-block camera
    encoders, FramePack components and MoE components, restores trained
    weights from ``dit_path`` (strict=False, so newly added MoE parameters
    may be absent from the checkpoint), then generates
    ``total_frames_to_generate`` latent frames in chunks of
    ``frames_per_generation``, conditioning each chunk on a sliding history
    window of at most ``max_history_frames`` latent frames.  Optional
    classifier-free guidance is applied over the camera embedding
    (unconditional branch = all-zero embedding) and/or the text prompt
    (negative prompt = empty string).  The decoded video is written to
    ``output_path``; nothing is returned.

    NOTE(review): the latents loaded from ``condition_pth_path`` are assumed
    to be laid out as (C, T, H, W) — the unpack below relies on that; confirm
    against the encoder that produced the .pth file.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialization: swap in the custom DiT class before loading.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Attach the legacy per-block camera encoder (kept for checkpoint
    #    compatibility).  Encoder is zero-initialized and the projector is
    #    identity, so before weights are loaded these layers are a no-op.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Attach FramePack components (clean-latent embedders etc.).
    add_framepack_components(pipe.dit)

    # 4. Attach MoE components; per-modality input dims = pose dims + 1 mask bit.
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,  # Sekai: 12-D pose + 1-D mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-D pose + 1-D mask
        "openx_input_dim": 13  # OpenX: 12-D pose + 1-D mask (same layout as sekai)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load trained weights; strict=False tolerates the newly added modules.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    # Keep the clean-latent embedder in the same dtype as the rest of the DiT.
    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition latents from the pre-encoded .pth file.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Spatial center-crop of the latent grid to the model's working size.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the prompt; when text CFG is on, also encode an empty
    #    negative prompt.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 8. Optional NuScenes scene metadata (used by the pose generator below).
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera-embedding sequence for the whole run,
    #    dispatching on the modality.
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            0,
            max_history_frames,
            0,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            0,
            max_history_frames,
            0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            0,
            max_history_frames,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. Unconditional camera embedding for CFG = zeros of the same shape.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        # Last chunk may be shorter than frames_per_generation.
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation (MoE variant): builds clean latents at
        # 1x/2x/4x temporal scales, their index tensors, and the camera
        # embedding with its condition mask.
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # MoE routes on the modality key.
        modality_inputs = {modality_type: camera_embedding}

        # Unconditional camera embedding for this chunk; camera_embedding is
        # [1, frames, D] after unsqueeze, so shape[1] is the frame count used
        # to slice the full-length zero embedding.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors are kept on CPU (presumably consumed as indices inside
        # the DiT — TODO confirm).
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Fresh Gaussian noise for the frames to be generated this chunk.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop with optional camera/text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # CFG inference.  Branch 1: camera CFG (optionally stacked
                # with text CFG); branch 2: text CFG only; branch 3: plain.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera embedding).
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zero camera embedding).
                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # MoE unconditional modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG combination.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack text CFG on top of the camera-guided result.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loss = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the camera-CFG result.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only.
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Plain inference, no CFG.
                    noise_pred, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the freshly denoised chunk to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: always keep the very first frame as an
        # anchor plus the most recent (max_history_frames - 1) frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode the concatenated latents and save the video.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    # Tiled VAE decode (bounds peak memory during decoding).
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # Map decoded output from [-1, 1] to uint8 RGB frames (T, H, W, C).
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
+
def main():
|
| 895 |
+
parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")
|
| 896 |
+
|
| 897 |
+
# 基��参数
|
| 898 |
+
parser.add_argument("--condition_pth", type=str,
|
| 899 |
+
default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
|
| 900 |
+
#default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
|
| 901 |
+
#default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
|
| 902 |
+
#default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
|
| 903 |
+
parser.add_argument("--start_frame", type=int, default=0)
|
| 904 |
+
parser.add_argument("--initial_condition_frames", type=int, default=16)
|
| 905 |
+
parser.add_argument("--frames_per_generation", type=int, default=8)
|
| 906 |
+
parser.add_argument("--total_frames_to_generate", type=int, default=40)
|
| 907 |
+
parser.add_argument("--max_history_frames", type=int, default=100)
|
| 908 |
+
parser.add_argument("--use_real_poses", action="store_true", default=False)
|
| 909 |
+
parser.add_argument("--dit_path", type=str,
|
| 910 |
+
default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_test/step1000_moe.ckpt")
|
| 911 |
+
parser.add_argument("--output_path", type=str,
|
| 912 |
+
default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
|
| 913 |
+
parser.add_argument("--prompt", type=str,
|
| 914 |
+
default="A drone flying scene in a game world")
|
| 915 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 916 |
+
|
| 917 |
+
# 模态类型参数
|
| 918 |
+
parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
|
| 919 |
+
help="模态类型:sekai 或 nuscenes 或 openx")
|
| 920 |
+
parser.add_argument("--scene_info_path", type=str, default=None,
|
| 921 |
+
help="NuScenes场景信息文件路径(仅用于nuscenes模态)")
|
| 922 |
+
|
| 923 |
+
# CFG参数
|
| 924 |
+
parser.add_argument("--use_camera_cfg", default=True,
|
| 925 |
+
help="使用Camera CFG")
|
| 926 |
+
parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
|
| 927 |
+
help="Camera guidance scale for CFG")
|
| 928 |
+
parser.add_argument("--text_guidance_scale", type=float, default=1.0,
|
| 929 |
+
help="Text guidance scale for CFG")
|
| 930 |
+
|
| 931 |
+
# MoE参数
|
| 932 |
+
parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量")
|
| 933 |
+
parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
|
| 934 |
+
parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
|
| 935 |
+
|
| 936 |
+
args = parser.parse_args()
|
| 937 |
+
|
| 938 |
+
print(f"🔧 MoE FramePack CFG生成设置:")
|
| 939 |
+
print(f"模态类型: {args.modality_type}")
|
| 940 |
+
print(f"Camera CFG: {args.use_camera_cfg}")
|
| 941 |
+
if args.use_camera_cfg:
|
| 942 |
+
print(f"Camera guidance scale: {args.camera_guidance_scale}")
|
| 943 |
+
print(f"Text guidance scale: {args.text_guidance_scale}")
|
| 944 |
+
print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
|
| 945 |
+
|
| 946 |
+
# 验证NuScenes参数
|
| 947 |
+
if args.modality_type == "nuscenes" and not args.scene_info_path:
|
| 948 |
+
print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")
|
| 949 |
+
|
| 950 |
+
inference_moe_framepack_sliding_window(
|
| 951 |
+
condition_pth_path=args.condition_pth,
|
| 952 |
+
dit_path=args.dit_path,
|
| 953 |
+
output_path=args.output_path,
|
| 954 |
+
start_frame=args.start_frame,
|
| 955 |
+
initial_condition_frames=args.initial_condition_frames,
|
| 956 |
+
frames_per_generation=args.frames_per_generation,
|
| 957 |
+
total_frames_to_generate=args.total_frames_to_generate,
|
| 958 |
+
max_history_frames=args.max_history_frames,
|
| 959 |
+
device=args.device,
|
| 960 |
+
prompt=args.prompt,
|
| 961 |
+
modality_type=args.modality_type,
|
| 962 |
+
use_real_poses=args.use_real_poses,
|
| 963 |
+
scene_info_path=args.scene_info_path,
|
| 964 |
+
# CFG参数
|
| 965 |
+
use_camera_cfg=args.use_camera_cfg,
|
| 966 |
+
camera_guidance_scale=args.camera_guidance_scale,
|
| 967 |
+
text_guidance_scale=args.text_guidance_scale,
|
| 968 |
+
# MoE参数
|
| 969 |
+
moe_num_experts=args.moe_num_experts,
|
| 970 |
+
moe_top_k=args.moe_top_k,
|
| 971 |
+
moe_hidden_dim=args.moe_hidden_dim
|
| 972 |
+
)
|
| 973 |
+
|
| 974 |
+
|
| 975 |
+
# Script entry point: only run the CLI when executed directly.
if __name__ == "__main__":
    main()
|
scripts/infer_nus.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import imageio
|
| 6 |
+
import json
|
| 7 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 8 |
+
import argparse
|
| 9 |
+
from torchvision.transforms import v2
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from pose_classifier import PoseClassifier
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_video_frames(video_path, num_frames=20, height=900, width=1600):
    """Load up to ``num_frames`` frames from a video and preprocess them.

    Each frame is resized to 832x480 (W x H), converted to a tensor and
    normalized to [-1, 1].  Returns a tensor of shape (C, T, H, W), or None
    if the video yields no frames.

    NOTE(review): the ``height`` and ``width`` parameters are currently
    unused — the resize target is hard-coded to (480, 832) below; confirm
    whether callers expect them to take effect.
    """
    frame_process = v2.Compose([
        # v2.CenterCrop(size=(height, width)),
        # v2.Resize(size=(height, width), antialias=True),
        v2.ToTensor(),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # map [0,1] -> [-1,1]
    ])

    def crop_and_resize(image):
        # Resize to the fixed working resolution; ``w``/``h`` are read but
        # not used (the aspect-preserving scale below is commented out).
        w, h = image.size
        # scale = max(width / w, height / h)
        image = v2.functional.resize(
            image,
            (round(480), round(832)),
            interpolation=v2.InterpolationMode.BILINEAR
        )
        return image

    reader = imageio.get_reader(video_path)
    frames = []

    # Read at most num_frames frames from the start of the video.
    for i, frame_data in enumerate(reader):
        if i >= num_frames:
            break
        frame = Image.fromarray(frame_data)
        frame = crop_and_resize(frame)
        frame = frame_process(frame)
        frames.append(frame)

    reader.close()

    if len(frames) == 0:
        return None

    # Stack to (T, C, H, W), then reorder to (C, T, H, W).
    frames = torch.stack(frames, dim=0)
    frames = rearrange(frames, "T C H W -> C T H W")
    return frames
|
| 54 |
+
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Express ``current_rotation`` in the frame of ``reference_rotation``.

    Computes q_rel = q_ref^-1 * q_cur via the Hamilton product, with
    quaternions in (w, x, y, z) order.  Unit quaternions are assumed, so the
    conjugate of the reference serves as its inverse.
    """
    q_cur = torch.tensor(current_rotation, dtype=torch.float32)
    q_ref = torch.tensor(reference_rotation, dtype=torch.float32)

    # Conjugate of the reference quaternion (== inverse for unit quaternions).
    rw, rx, ry, rz = q_ref
    q_ref_conj = torch.tensor([rw, -rx, -ry, -rz])

    # Hamilton product q_ref_conj * q_cur.
    aw, ax, ay, az = q_ref_conj
    bw, bx, by, bz = q_cur
    return torch.stack([
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ])
| 75 |
+
def generate_direction_poses(direction="forward", target_frames=10, condition_frames=20):
    """Generate per-frame pose class embeddings for a scripted camera motion.

    Builds ``condition_frames`` of slow forward motion (frame type 0)
    followed by ``target_frames`` of the requested motion (frame type 1),
    classifies the target portion with PoseClassifier, and returns enhanced
    512-dim class embeddings for the full sequence.

    Args:
        direction: one of 'forward', 'backward', 'left_turn', 'right_turn'.
            BUGFIX: the previous default 'left' was not a supported value,
            so calling with defaults always raised ValueError; the default
            is now 'forward'.
        target_frames: number of frames with the requested motion.
        condition_frames: number of leading conditioning frames.

    Returns:
        torch.Tensor of shape [condition_frames + target_frames, 512].

    Raises:
        ValueError: if ``direction`` is not one of the supported motions.
    """
    classifier = PoseClassifier()

    total_frames = condition_frames + target_frames
    # BUGFIX: debug output was garbled ("conditon{n}" / "target{n}").
    print(f"condition frames: {condition_frames}")
    print(f"target frames: {target_frames}")
    poses = []

    # Condition frames: steady slow motion along -x, no rotation.
    for i in range(condition_frames):
        t = i / max(1, condition_frames - 1)  # normalized progress in [0, 1]

        translation = [-t * 0.5, 0.0, 0.0]  # slow advance
        rotation = [1.0, 0.0, 0.0, 0.0]  # identity quaternion (w, x, y, z)
        frame_type = 0.0  # condition marker

        pose_vec = translation + rotation + [frame_type]  # 8-D pose vector
        poses.append(pose_vec)

    # Target frames: motion continues from where the condition segment ended
    # (offset -condition_frames * 0.5 along x).
    for i in range(target_frames):
        t = i / max(1, target_frames - 1)  # normalized progress in [0, 1]

        if direction == "forward":
            # Keep advancing along -x, no rotation.
            translation = [-(condition_frames * 0.5 + t * 2.0), 0.0, 0.0]
            rotation = [1.0, 0.0, 0.0, 0.0]

        elif direction == "backward":
            # Reverse along +x from the end of the condition segment.
            translation = [-(condition_frames * 0.5) + t * 2.0, 0.0, 0.0]
            rotation = [1.0, 0.0, 0.0, 0.0]

        elif direction == "left_turn":
            # Advance while drifting left and yawing positively about z.
            translation = [-(condition_frames * 0.5 + t * 1.5), t * 0.5, 0.0]
            yaw = t * 0.3  # radians
            rotation = [
                np.cos(yaw / 2),  # w
                0.0,              # x
                0.0,              # y
                np.sin(yaw / 2)   # z (positive = left turn)
            ]

        elif direction == "right_turn":
            # Advance while drifting right and yawing negatively about z.
            translation = [-(condition_frames * 0.5 + t * 1.5), -t * 0.5, 0.0]
            yaw = -t * 0.3  # radians
            rotation = [
                np.cos(abs(yaw) / 2),  # w (cos is even, abs() is cosmetic)
                0.0,                   # x
                0.0,                   # y
                np.sin(yaw / 2)        # z (negative = right turn)
            ]
        else:
            raise ValueError(f"Unknown direction: {direction}")

        frame_type = 1.0  # target marker
        pose_vec = translation + rotation + [frame_type]
        poses.append(pose_vec)

    pose_sequence = torch.tensor(poses, dtype=torch.float32)

    # Classify only the target portion (first 7 dims; frame-type flag dropped).
    target_pose_sequence = pose_sequence[condition_frames:, :7]

    # Condition frames are all labeled 'forward' (class 0).
    condition_classes = torch.full((condition_frames,), 0, dtype=torch.long)
    target_classes = classifier.classify_pose_sequence(target_pose_sequence)
    full_classes = torch.cat([condition_classes, target_classes], dim=0)

    # Enhanced embeddings combine class, frame type and raw geometry.
    class_embeddings = create_enhanced_class_embedding_for_inference(
        full_classes, pose_sequence, embed_dim=512
    )

    print(f"Generated {direction} poses:")
    print(f"  Total frames: {total_frames} (condition: {condition_frames}, target: {target_frames})")
    analysis = classifier.analyze_pose_sequence(target_pose_sequence)
    print(f"  Target class distribution: {analysis['class_distribution']}")
    print(f"  Target motion segments: {len(analysis['motion_segments'])}")

    return class_embeddings
| 167 |
+
def create_enhanced_class_embedding_for_inference(class_labels: torch.Tensor, pose_sequence: torch.Tensor, embed_dim: int = 512) -> torch.Tensor:
    """Build enhanced per-frame class embeddings for inference.

    Concatenates a 4-way motion-direction encoding, a 2-way condition/target
    frame-type encoding, and the raw translation (3) + rotation (4) pose
    components into a 13-dim feature per frame, then expands (or truncates)
    to ``embed_dim`` columns.

    NOTE(review): when embed_dim > 13 the extra columns come from an
    unseeded torch.randn projection, so repeated calls with identical inputs
    differ beyond the first 13 dims — confirm this is intended at inference.
    """
    n_classes = 4

    # Canonical direction vectors: forward, backward, left_turn, right_turn.
    direction_table = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],   # forward
        [-1.0, 0.0, 0.0, 0.0],  # backward
        [0.0, 1.0, 0.0, 0.0],   # left_turn
        [0.0, -1.0, 0.0, 0.0],  # right_turn
    ], dtype=torch.float32)

    # One-hot class encoding mapped through the direction table.
    onehot = torch.nn.functional.one_hot(class_labels, n_classes).to(torch.float32)
    direction_features = onehot @ direction_table  # [num_frames, 4]

    # Frame type lives in the last pose column: col 0 flags condition (0),
    # col 1 flags target (1).
    kinds = pose_sequence[:, -1]
    kind_features = torch.stack(
        [(kinds == 0).float(), (kinds == 1).float()], dim=1
    )  # [num_frames, 2]

    # Assemble direction + frame-type + raw geometry into 13 dims per frame.
    features = torch.cat([
        direction_features,       # [num_frames, 4]
        kind_features,            # [num_frames, 2]
        pose_sequence[:, :3],     # translations [num_frames, 3]
        pose_sequence[:, 3:7],    # rotations   [num_frames, 4]
    ], dim=1)  # [num_frames, 13]

    if embed_dim > 13:
        # Random expansion with an identity block so the first 13 output
        # dims reproduce the raw features exactly.
        projection = torch.randn(13, embed_dim) * 0.1
        projection[:, :13] = torch.eye(13)
        return features @ projection
    return features[:, :embed_dim]
|
| 215 |
+
def generate_poses_from_file(poses_path, target_frames=10):
    """Build a class embedding sequence from a poses.json file.

    Reads the recorded relative poses, resamples them onto
    ``target_frames`` frames, classifies the resulting pose sequence and
    returns the classifier's 512-D class embeddings. Falls back to the
    synthetic "forward" direction when the file contains no poses.
    """
    classifier = PoseClassifier()

    with open(poses_path, 'r') as f:
        poses_data = json.load(f)

    target_relative_poses = poses_data['target_relative_poses']

    if not target_relative_poses:
        print("No poses found in file, using forward direction")
        return generate_direction_poses("forward", target_frames)

    num_recorded = len(target_relative_poses)
    pose_vecs = []
    for frame_idx in range(target_frames):
        # Resample: a single recorded pose is broadcast; otherwise sample
        # proportionally along the recorded sequence (clamped to the end).
        if num_recorded == 1:
            entry = target_relative_poses[0]
        else:
            sample_idx = min(frame_idx * num_recorded // target_frames,
                             num_recorded - 1)
            entry = target_relative_poses[sample_idx]

        # Pull out relative translation plus the two absolute rotations.
        translation = torch.tensor(entry['relative_translation'], dtype=torch.float32)
        current_rotation = torch.tensor(entry['current_rotation'], dtype=torch.float32)
        reference_rotation = torch.tensor(entry['reference_rotation'], dtype=torch.float32)

        # Relative rotation of the current frame w.r.t. the reference frame.
        relative_rotation = calculate_relative_rotation(current_rotation, reference_rotation)

        # 7-D pose vector: [tx, ty, tz] + relative rotation.
        pose_vecs.append(torch.cat([translation, relative_rotation], dim=0))

    pose_sequence = torch.stack(pose_vecs, dim=0)

    # Classify the sequence and turn the labels into class embeddings.
    class_embeddings = classifier.create_class_embedding(
        classifier.classify_pose_sequence(pose_sequence),
        embed_dim=512
    )

    print(f"Generated poses from file:")
    analysis = classifier.analyze_pose_sequence(pose_sequence)
    print(f" Class distribution: {analysis['class_distribution']}")
    print(f" Motion segments: {len(analysis['motion_segments'])}")

    return class_embeddings
def inference_nuscenes_video(
    condition_video_path,
    dit_path,
    text_encoder_path,
    vae_path,
    output_path="nus/infer_results/output_nuscenes.mp4",
    condition_frames=20,
    target_frames=3,
    height=900,
    width=1600,
    device="cuda",
    prompt="A car driving scene captured by front camera",
    poses_path=None,
    direction="forward"
):
    """
    Direction-class-controlled inference, distinguishing condition and
    target poses: loads a condition video, builds a direction-conditioned
    camera embedding, then denoises only the target latent frames and
    appends them to the condition latents before decoding to MP4.

    Args:
        condition_video_path: path to the conditioning video.
        dit_path: checkpoint with trained DiT weights (incl. cam_encoder/projector).
        text_encoder_path: NOTE(review) - unused; the encoder is loaded from
            the hard-coded model list below.
        vae_path: NOTE(review) - unused; same as above.
        output_path: output MP4 path (parent dirs created as needed).
        condition_frames: number of raw condition frames (latent frames are
            this divided by 4 - the VAE's temporal compression, presumably;
            TODO confirm).
        target_frames: number of target LATENT frames to generate.
        height, width: frame size used when loading the condition video.
        device: inference device.
        prompt: text prompt for generation.
        poses_path: NOTE(review) - accepted but never used; generation always
            follows `direction`.
        direction: one of forward/backward/left_turn/right_turn.
    """
    os.makedirs(os.path.dirname(output_path),exist_ok=True)

    print(f"Setting up models for {direction} movement...")

    # 1. Load models (same as before).
    # NOTE(review): model paths are hard-coded here rather than taken from
    # the text_encoder_path / vae_path parameters.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add camera components to DiT: a zero-initialized camera encoder and an
    # identity-initialized projector per block, so the untrained layers are
    # a no-op until the checkpoint below overwrites them.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(512, dim)  # keep the 512-D embedding
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Load trained DiT weights (strict: checkpoint must match exactly).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    pipe.scheduler.set_timesteps(50)

    print("Loading condition video...")

    # Load condition video frames at the requested resolution.
    condition_video = load_video_frames(
        condition_video_path,
        num_frames=condition_frames,
        height=height,
        width=width
    )

    if condition_video is None:
        raise ValueError(f"Failed to load condition video from {condition_video_path}")

    condition_video = condition_video.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

    print("Processing poses...")

    # Generate a pose embedding covering both condition and target frames.
    print(f"Generating {direction} movement poses...")
    camera_embedding = generate_direction_poses(
        direction=direction,
        target_frames=target_frames,
        condition_frames=int(condition_frames/4)  # condition frame count after temporal compression
    )

    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16)

    print(f"Camera embedding shape: {camera_embedding.shape}")
    print(f"Generated poses for direction: {direction}")

    print("Encoding inputs...")

    # Encode text prompt.
    prompt_emb = pipe.encode_prompt(prompt)

    # Encode condition video into latents (tiled VAE encode).
    condition_latents = pipe.encode_video(condition_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))[0]

    print("Generating video...")

    # Generate target latents.
    batch_size = 1
    channels = condition_latents.shape[0]
    latent_height = condition_latents.shape[2]
    latent_width = condition_latents.shape[3]
    target_height, target_width = 60, 104  # adjust to your setup

    if latent_height > target_height or latent_width > target_width:
        # Center-crop the condition latents down to the target latent size.
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2
        condition_latents = condition_latents[:, :,
                                              h_start:h_start+target_height,
                                              w_start:w_start+target_width]
        latent_height = target_height
        latent_width = target_width
    condition_latents = condition_latents.to(device, dtype=pipe.torch_dtype)
    condition_latents = condition_latents.unsqueeze(0)
    condition_latents = condition_latents + 0.05 * torch.randn_like(condition_latents)  # add slight noise for diversity

    # Initialize target latents with noise.
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=pipe.torch_dtype
    )
    print(target_latents.shape)
    print(camera_embedding.shape)
    # Combine condition and target latents along the time axis.
    combined_latents = torch.cat([condition_latents, target_latents], dim=2)
    print(combined_latents.shape)

    # Prepare extra inputs expected by the DiT.
    extra_input = pipe.prepare_extra_input(combined_latents)

    # Denoising loop: predict noise over the whole sequence but only step
    # the target portion; condition latents stay fixed.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        # Prepare timestep.
        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

        # Predict noise.
        with torch.no_grad():
            noise_pred = pipe.dit(
                combined_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                **prompt_emb,
                **extra_input
            )

        # Update only the target part (after the compressed condition frames).
        target_noise_pred = noise_pred[:, :, int(condition_frames/4):, :, :]
        target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents)

        # Write the stepped target latents back into the combined tensor.
        combined_latents[:, :, int(condition_frames/4):, :, :] = target_latents

    print("Decoding video...")

    # Decode final video (tiled VAE decode).
    final_video = torch.cat([condition_latents, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # Save video.
    print(f"Saving video to {output_path}")

    # Convert to numpy HWC uint8 frames and write MP4.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()  # cast to float32 first
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)  # denormalize from [-1, 1]
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"Video generation completed! Saved to {output_path}")
def main():
    """CLI entry point: parse arguments and run direction-controlled inference."""
    parser = argparse.ArgumentParser(description="NuScenes Video Generation Inference with Direction Control")
    parser.add_argument("--condition_video", type=str, default="/home/zhuyixuan05/ReCamMaster/nus/videos/4032/right.mp4",
                        help="Path to condition video")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement")
    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/nus_dynamic/step15000_dynamic.ckpt",
                        help="Path to trained DiT checkpoint")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
                        help="Path to text encoder")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
                        help="Path to VAE")
    parser.add_argument("--output_path", type=str, default="nus/infer_results-15000/right_left.mp4",
                        help="Output video path")
    parser.add_argument("--poses_path", type=str, default=None,
                        help="Path to poses.json file (optional, will use direction if not provided)")
    parser.add_argument("--prompt", type=str,
                        default="A car driving scene captured by front camera",
                        help="Text prompt for generation")
    parser.add_argument("--condition_frames", type=int, default=40,
                        help="Number of condition frames")
    # This is the raw (uncompressed) frame count.
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate")
    # This one is divided by 4 downstream (temporal compression).
    parser.add_argument("--height", type=int, default=900,
                        help="Video height")
    parser.add_argument("--width", type=int, default=1600,
                        help="Video width")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    condition_video_path = args.condition_video
    input_filename = os.path.basename(condition_video_path)
    output_dir = "nus/infer_results"
    os.makedirs(output_dir, exist_ok=True)

    # Include the direction in the output filename when no explicit path given.
    # NOTE(review): this branch is currently unreachable - --output_path has a
    # non-None default above, so args.output_path is never None. If the
    # direction-based naming is intended, the argparse default should be None.
    if args.output_path is None:
        name_parts = os.path.splitext(input_filename)
        output_filename = f"{name_parts[0]}_{args.direction}{name_parts[1]}"
        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = args.output_path

    print(f"Output video will be saved to: {output_path}")
    inference_nuscenes_video(
        condition_video_path=args.condition_video,
        dit_path=args.dit_path,
        text_encoder_path=args.text_encoder_path,
        vae_path=args.vae_path,
        output_path=output_path,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        height=args.height,
        width=args.width,
        device=args.device,
        prompt=args.prompt,
        poses_path=args.poses_path,
        direction=args.direction  # newly added
    )

if __name__ == "__main__":
    main()
|
scripts/infer_openx.py
ADDED
|
@@ -0,0 +1,614 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 3 |
+
from torchvision.transforms import v2
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import argparse
|
| 9 |
+
import numpy as np
|
| 10 |
+
import imageio
|
| 11 |
+
import copy
|
| 12 |
+
import random
|
| 13 |
+
|
| 14 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice a frame window.

    Returns a tuple of (condition_latents, encoded_data) where
    condition_latents is the [C, num_frames, H, W] slice starting at
    ``start_frame`` and encoded_data is the full dict loaded from disk.
    Raises ValueError when the requested window exceeds the stored frames.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return camera B's pose relative to camera A, i.e. pose_b @ inv(pose_a).

    Both inputs must be 4x4 extrinsic matrices; they are converted to the
    backend selected by ``use_torch`` (torch tensors or numpy arrays) if
    needed.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Torch backend: promote numpy inputs to float tensors.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # NumPy backend (default).
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
def replace_dit_model_in_manager():
    """Patch the global model loader config before models are loaded.

    Every config entry that registers a 'wan_video_dit' model has its model
    class swapped for WanModelFuture; the registered name stays the same so
    checkpoint matching is unaffected.
    """
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        # Only entries registering a wan_video_dit model need patching.
        if 'wan_video_dit' not in model_names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(model_names, model_classes):
            if name == 'wan_video_dit':
                patched_names.append(name)           # keep the registered name
                patched_classes.append(WanModelFuture)  # swap in the new class
                print(f"✅ 替换了模型类: {name} -> WanModelFuture")
            else:
                patched_names.append(name)
                patched_classes.append(cls)

        # Write the patched entry back into the global config list.
        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, model_resource)
def add_framepack_components(dit_model):
    """Attach the FramePack clean_x_embedder to a DiT model (no-op if present).

    The embedder offers three patchify scales (1x/2x/4x) implemented as
    Conv3d projections, mirroring the design in hunyuan_video_packed.py,
    and is cast to the model's parameter dtype.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Multi-scale latent patchifier used by FramePack conditioning."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch to the projection matching the requested scale.
            projections = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in projections:
                raise ValueError(f"Unsupported scale: {scale}")
            return projections[scale](x)

    embedder = CleanXEmbedder(inner_dim)
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
def generate_openx_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Generate camera embeddings for the OpenX dataset (sliding-window version).

    Produces a [max_needed_frames, 13] bfloat16 tensor: 12 dims are the
    flattened 3x4 relative pose between consecutive (stride-4) frames, and
    the last dim is a condition mask (1.0 for frames inside the current
    history window, 0.0 otherwise). Uses the recorded extrinsics when
    available, otherwise synthesizes a gentle robot-arm-like motion.

    Args:
        cam_data: dict with an 'extrinsic' list of 4x4 matrices, or None.
        start_frame: index where the condition window begins.
        current_history_length: number of condition frames in the window.
        new_frames: number of new latent frames FramePack will generate.
        total_generated: NOTE(review) - accepted but never used here.
        use_real_poses: fall back to synthetic motion when False.
    """
    time_compression_ratio = 4

    # Camera frame count FramePack actually needs: the fixed index layout
    # is [start(1) | 4x(16) | 2x(2) | 1x(1)] plus the new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实OpenX camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Make sure the generated camera sequence is long enough for both
        # the sliding window and the FramePack layout (min 30 frames).
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX-specific: one latent frame per 4 raw frames, so step the
            # extrinsics with the same stride.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_cam = compute_relative_pose(cam_prev, cam_next)
                # Keep only the 3x4 [R|t] part of the relative pose.
                relative_poses.append(torch.as_tensor(relative_cam[:3, :]))
            else:
                # Past the end of the recorded data: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask matching the sequence length: frames from
        # start_frame through the current history are marked as condition.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用OpenX合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX robot-manipulation pattern - stable, small-amplitude
            # motion imitating a robot arm's fine adjustments.
            forward_speed = 0.001                      # per-frame forward step (tiny: fine manipulation)
            lateral_motion = 0.0005 * np.sin(i * 0.05)  # slight left/right drift
            vertical_motion = 0.0003 * np.cos(i * 0.1)  # slight up/down drift

            # Rotation deltas (small viewpoint adjustments).
            yaw_change = 0.01 * np.sin(i * 0.03)     # slight yaw
            pitch_change = 0.008 * np.cos(i * 0.04)  # slight pitch

            pose = np.eye(4, dtype=np.float32)

            # Small-angle rotation terms about the Y (yaw) and X (pitch) axes.
            cos_yaw = np.cos(yaw_change)
            sin_yaw = np.sin(yaw_change)
            cos_pitch = np.cos(pitch_change)
            sin_pitch = np.sin(pitch_change)

            # Combined rotation entries (pitch then yaw).
            # NOTE(review): this is not an exact product of two rotation
            # matrices (e.g. pose[1,0] stays 0); for these tiny angles the
            # deviation is negligible, but confirm if exactness matters.
            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[1, 1] = cos_pitch
            pose[1, 2] = -sin_pitch
            pose[2, 0] = -sin_yaw
            pose[2, 1] = sin_pitch
            pose[2, 2] = cos_yaw * cos_pitch

            # Translation (fine-manipulation scale).
            pose[0, 3] = lateral_motion    # X axis (left/right)
            pose[1, 3] = vertical_motion   # Y axis (up/down)
            pose[2, 3] = -forward_speed    # Z axis (negative = forward)

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask matching the sequence length (same rule as above).
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """FramePack sliding-window preparation - OpenX version.

    Builds the fixed FramePack index layout
    [start(1) | 4x(16) | 2x(2) | 1x(1) | new(target_frames_to_generate)],
    selects the matching slice of the camera embedding (zero-padded if too
    short), rebuilds its condition mask from the actual history length, and
    right-aligns the most recent history into the 19 clean-latent slots.

    Args:
        history_latents: [C, T, H, W] latents accumulated so far.
        target_frames_to_generate: new latent frames per generation step.
        camera_embedding_full: [N, D] camera sequence; last column is the mask.
        start_frame: kept for interface parity (mask is derived from T here).
        max_history_frames: kept for interface parity.

    Returns:
        dict of index tensors, packed clean latents, the masked camera slice,
        and the current/next history lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout; its total length also dictates how many camera
    # frames are consumed.
    total_len = 1 + 16 + 2 + 1 + target_frames_to_generate
    all_indices = torch.arange(0, total_len)
    (clean_latent_indices_start,
     clean_latent_4x_indices,
     clean_latent_2x_indices,
     clean_latent_1x_indices,
     latent_indices) = all_indices.split([1, 16, 2, 1, target_frames_to_generate], dim=0)
    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence when it is shorter than the layout.
    if camera_embedding_full.shape[0] < total_len:
        shortage = total_len - camera_embedding_full.shape[0]
        pad = torch.zeros(shortage, camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Take the slice backing this window and rebuild its mask column:
    # everything starts as target (0)...
    combined_camera = camera_embedding_full[:total_len, :].clone()
    combined_camera[:, -1] = 0.0

    # ...then the camera rows backing valid clean latents become condition.
    # The 19 clean slots are right-aligned, so the last min(T, 19) rows
    # before index 19 get the condition flag.
    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        combined_camera[start_pos:19, -1] = 1.0

    print(f"🔧 OpenX Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")

    # Pack the most recent history into the 19 clean-latent slots,
    # right-aligned; missing history stays zero.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    # Split the 19 slots into the three FramePack scales.
    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # Anchor frame: the very first history frame (zeros when no history).
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
def inference_openx_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="openx_results/output_openx_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of robotic manipulation task with camera movement",
    use_real_poses=True,
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0
):
    """OpenX FramePack sliding-window video generation.

    Autoregressively generates ``total_frames_to_generate`` latent frames in
    chunks of ``frames_per_generation``, conditioning each chunk on a sliding
    window of history latents (FramePack multi-scale clean latents) plus
    per-frame camera embeddings, then decodes the full latent sequence and
    writes it to ``output_path`` as an mp4.

    Args:
        condition_pth_path: .pth file with pre-encoded latents (and optionally
            a 'cam_emb' entry with camera data) used as the initial condition.
        dit_path: checkpoint with the trained DiT weights (loaded strict=True).
        output_path: output mp4 path; its parent directory is created.
        start_frame: offset into the encoded latents for the condition window.
        initial_condition_frames: number of latent frames used as history seed.
        frames_per_generation: latent frames generated per window step.
        total_frames_to_generate: total latent frames to generate.
        max_history_frames: cap on the history window (first frame always kept).
        device: torch device string for the pipeline.
        prompt: text prompt; an empty prompt is used as the text-CFG negative.
        use_real_poses: forwarded to camera-embedding generation.
        use_camera_cfg / camera_guidance_scale: enable/scale camera
            classifier-free guidance (extra pass with zeroed camera embedding).
        text_guidance_scale: >1.0 enables text CFG (extra pass with "" prompt).
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 OpenX FramePack滑动窗口生成开始...")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")

    # 1. Model initialization — patch the model registry so the MoE DiT class
    # is instantiated when the checkpoint below is loaded.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Camera encoder per block: zero-initialized linear so camera
    # conditioning starts as a no-op; projector starts as identity.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. FramePack components (clean_x_embedder) must exist before the strict
    # state_dict load below.
    add_framepack_components(pipe.dit)

    # 4. Load the trained weights.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 5. Load the initial condition latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop spatially to the OpenX latent size.
    target_height, target_width = 60, 104
    # C/H/W are reused below when sampling noise for each generation chunk.
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 6. Encode the prompt; a second "" encoding is only needed for text CFG.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 7. Pre-generate the full camera embedding sequence once; each window step
    # slices from it. NOTE(review): the two trailing 0 positional args look like
    # new_frames / total_generated placeholders — confirm against the local
    # generate_openx_camera_embeddings_sliding signature.
    camera_embedding_full = generate_openx_camera_embeddings_sliding(
        encoded_data.get('cam_emb', None),
        0,
        max_history_frames,
        0,
        0,
        use_real_poses=use_real_poses
    ).to(device, dtype=model_dtype)

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 8. Unconditional (all-zero) camera embedding for camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 9. Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation — OpenX variant.
        framepack_data = prepare_framepack_sliding_window_with_camera(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # Matching-length unconditional camera embedding for CFG.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)

        # Index tensors stay on CPU.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Fresh noise for the frames to be generated this step.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop with optional camera and text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # Conditional prediction (with camera conditioning).
                noise_pred_pos = pipe.dit(
                    new_latents,
                    timestep=timestep_tensor,
                    cam_emb=camera_embedding,
                    latent_indices=latent_indices,
                    clean_latents=clean_latents,
                    clean_latent_indices=clean_latent_indices,
                    clean_latents_2x=clean_latents_2x,
                    clean_latent_2x_indices=clean_latent_2x_indices,
                    clean_latents_4x=clean_latents_4x,
                    clean_latent_4x_indices=clean_latent_4x_indices,
                    **prompt_emb_pos,
                    **extra_input
                )

                # Camera CFG: second pass with zeroed camera embedding.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Standard CFG combination for the camera condition.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_pos - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_pos

                # Text CFG: third pass with the empty-prompt embedding.
                if prompt_emb_neg is not None and text_guidance_scale > 1.0:
                    noise_pred_text_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    # Text CFG applied on top of the camera-guided prediction.
                    noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the denoised chunk to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep the very first frame as an anchor
        # plus the most recent (max_history_frames - 1) frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 10. Decode and save.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    # NOTE(review): initial_latents keeps its original dtype while all_generated
    # is model_dtype (bf16); torch.cat requires matching dtypes — confirm the
    # encoded .pth latents are already bf16.
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] -> [0, 255] uint8, frame-major layout for imageio.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 OpenX FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
|
| 555 |
+
|
| 556 |
+
def main():
    """CLI entry point: parse arguments and run the OpenX sliding-window inference."""
    parser = argparse.ArgumentParser(description="OpenX FramePack滑动窗口视频生成")

    # Basic I/O and generation parameters.
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth",
                        help="输入编码视频路径")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=24)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/openx/openx_framepack/step2000.ckpt",
                        help="训练好的模型权重路径")
    parser.add_argument("--output_path", type=str,
                        default='openx_results/output_openx_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A video of robotic manipulation task with camera movement")
    parser.add_argument("--device", type=str, default="cuda")

    # CFG parameters.
    # NOTE(review): action="store_true" with default=True means camera CFG can
    # never be disabled from the command line — confirm whether a
    # --no_camera_cfg style switch was intended.
    parser.add_argument("--use_camera_cfg", action="store_true", default=True,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    args = parser.parse_args()

    print(f"🔧 OpenX FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"OpenX特有特性: camera间隔为4帧,适用于机器人操作任务")

    inference_openx_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale
    )
|
| 612 |
+
|
| 613 |
+
if __name__ == "__main__":
|
| 614 |
+
main()
|
scripts/infer_origin.py
ADDED
|
@@ -0,0 +1,1108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
import copy
import json
import os

import imageio
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image
from scipy.spatial.transform import Rotation as R
from torchvision.transforms import v2

from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 13 |
+
|
| 14 |
+
def compute_relative_pose_matrix(pose1, pose2):
    """Relative pose between two consecutive frames as a 3x4 [R_rel | t_rel] matrix.

    Args:
        pose1: camera pose of frame i, shape (7,): [tx, ty, tz, qx, qy, qz, qw].
        pose2: camera pose of frame i+1, same layout.

    Returns:
        np.ndarray of shape (3, 4): relative rotation in the first three
        columns, relative translation (in frame i's coordinates) in the last.
    """
    # Split each pose into translation and quaternion parts.
    t_prev, q_prev = pose1[:3], pose1[3:]
    t_next, q_next = pose2[:3], pose2[3:]

    rot_prev = R.from_quat(q_prev)
    rot_next = R.from_quat(q_next)

    # Relative rotation: next rotation composed with the inverse of the previous.
    rel_rotation = (rot_next * rot_prev.inv()).as_matrix()

    # Relative translation expressed in the previous frame: R_prev^T (t_next - t_prev).
    rel_translation = rot_prev.as_matrix().T @ (t_next - t_prev)

    # Assemble the 3x4 [R | t] matrix.
    return np.hstack([rel_rotation, rel_translation.reshape(3, 1)])
|
| 45 |
+
|
| 46 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a temporal window of pre-encoded video latents from a .pth file.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice
        starting at ``start_frame`` and the full loaded payload.

    Raises:
        ValueError: when the requested window runs past the stored sequence.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Relative pose of camera B with respect to camera A: pose_b @ inv(pose_a).

    Args:
        pose_a, pose_b: 4x4 camera extrinsic matrices (numpy arrays, or torch
            tensors when ``use_torch`` is True).
        use_torch: compute with torch ops instead of numpy.

    Returns:
        4x4 relative pose in the same framework as the inputs.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float tensors.
        a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(b, torch.inverse(a))

    # Numpy path: coerce to float32 arrays if needed.
    a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(b, np.linalg.inv(a))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def replace_dit_model_in_manager():
    """Patch the global model-loader registry so 'wan_video_dit' resolves to the MoE class."""
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        if 'wan_video_dit' not in model_names:
            continue

        # Rebuild the parallel name/class lists, swapping only the DiT class.
        patched_names = []
        patched_classes = []
        for name, cls in zip(model_names, model_classes):
            patched_names.append(name)
            if name == 'wan_video_dit':
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, model_resource)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to the DiT (idempotent)."""
    if hasattr(dit_model, 'clean_x_embedder'):
        return  # already installed — leave the existing embedder untouched

    # Width is taken from the attention projection of the first block.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Patchifies 16-channel clean latents at 1x / 2x / 4x compression."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch on scale; cast the input to the layer's weight dtype first.
            layers = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in layers:
                raise ValueError(f"Unsupported scale: {scale}")
            layer = layers[scale]
            return layer(x.to(layer.weight.dtype))

    # Attach first, then align the embedder's dtype with the model's parameters.
    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def add_moe_components(dit_model, moe_config):
    """🔧 Attach mixture-of-experts camera-conditioning components to the DiT (corrected version)."""
    # Store the config once; guarded so a repeated call does not overwrite it.
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")
    # NOTE(review): the model-level default top_k is 1 here, but each block's MoE
    # below defaults top_k to 2 — confirm which default is intended.
    dit_model.top_k = moe_config.get("top_k", 1)

    # Size the per-block MoE from the attention projection width.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)
    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
    # One processor per data modality, each mapping its raw camera-embedding
    # dimensionality to the shared unified_dim.
    dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
    dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
    dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim)  # OpenX uses a 13-dim input like sekai but is processed independently
    dit_model.global_router = nn.Linear(unified_dim, num_experts)


    for i, block in enumerate(dit_model.blocks):
        # MoE network — consumes unified_dim features, emits block-width features.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,  # output width matches the transformer block dim
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True,direction="left"):
|
| 176 |
+
"""为Sekai数据集生成camera embeddings - 滑动窗口版本"""
|
| 177 |
+
time_compression_ratio = 4
|
| 178 |
+
|
| 179 |
+
# 计算FramePack实际需要的camera帧数
|
| 180 |
+
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
|
| 181 |
+
|
| 182 |
+
if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
|
| 183 |
+
print("🔧 使用真实Sekai camera数据")
|
| 184 |
+
cam_extrinsic = cam_data['extrinsic']
|
| 185 |
+
|
| 186 |
+
# 确保生成足够长的camera序列
|
| 187 |
+
max_needed_frames = max(
|
| 188 |
+
start_frame + current_history_length + new_frames,
|
| 189 |
+
framepack_needed_frames,
|
| 190 |
+
30
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
print(f"🔧 计算Sekai camera序列长度:")
|
| 194 |
+
print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
|
| 195 |
+
print(f" - FramePack需求: {framepack_needed_frames}")
|
| 196 |
+
print(f" - 最终生成: {max_needed_frames}")
|
| 197 |
+
|
| 198 |
+
relative_poses = []
|
| 199 |
+
for i in range(max_needed_frames):
|
| 200 |
+
# 计算当前帧在原始序列中的位置
|
| 201 |
+
frame_idx = i * time_compression_ratio
|
| 202 |
+
next_frame_idx = frame_idx + time_compression_ratio
|
| 203 |
+
|
| 204 |
+
if next_frame_idx < len(cam_extrinsic):
|
| 205 |
+
cam_prev = cam_extrinsic[frame_idx]
|
| 206 |
+
cam_next = cam_extrinsic[next_frame_idx]
|
| 207 |
+
relative_pose = compute_relative_pose(cam_prev, cam_next)
|
| 208 |
+
relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
|
| 209 |
+
else:
|
| 210 |
+
# 超出范围,使用零运动
|
| 211 |
+
print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
|
| 212 |
+
relative_poses.append(torch.zeros(3, 4))
|
| 213 |
+
|
| 214 |
+
pose_embedding = torch.stack(relative_poses, dim=0)
|
| 215 |
+
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
|
| 216 |
+
|
| 217 |
+
# 创建对应长度的mask序列
|
| 218 |
+
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 219 |
+
# 从start_frame到current_history_length标记为condition
|
| 220 |
+
condition_end = min(start_frame + current_history_length, max_needed_frames)
|
| 221 |
+
mask[start_frame:condition_end] = 1.0
|
| 222 |
+
|
| 223 |
+
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
|
| 224 |
+
print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
|
| 225 |
+
return camera_embedding.to(torch.bfloat16)
|
| 226 |
+
|
| 227 |
+
else:
|
| 228 |
+
if direction=="left":
|
| 229 |
+
print("-----Left-------")
|
| 230 |
+
|
| 231 |
+
max_needed_frames = max(
|
| 232 |
+
start_frame + current_history_length + new_frames,
|
| 233 |
+
framepack_needed_frames,
|
| 234 |
+
30
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
|
| 238 |
+
relative_poses = []
|
| 239 |
+
for i in range(max_needed_frames):
|
| 240 |
+
# 持续左转运动模式
|
| 241 |
+
yaw_per_frame = 0.05 # 每帧左转(正角度表示左转)
|
| 242 |
+
forward_speed = 0.05 # 每帧前进距离
|
| 243 |
+
|
| 244 |
+
pose = np.eye(4, dtype=np.float32)
|
| 245 |
+
|
| 246 |
+
# 旋转矩阵(绕Y轴左转)
|
| 247 |
+
cos_yaw = np.cos(yaw_per_frame)
|
| 248 |
+
sin_yaw = np.sin(yaw_per_frame)
|
| 249 |
+
|
| 250 |
+
pose[0, 0] = cos_yaw
|
| 251 |
+
pose[0, 2] = sin_yaw
|
| 252 |
+
pose[2, 0] = -sin_yaw
|
| 253 |
+
pose[2, 2] = cos_yaw
|
| 254 |
+
|
| 255 |
+
# 平移(在旋转后的局部坐标系中前进)
|
| 256 |
+
pose[2, 3] = -forward_speed # 局部Z轴负方向(前进)
|
| 257 |
+
|
| 258 |
+
# 添加轻微的向心运动,模拟圆形轨迹
|
| 259 |
+
radius_drift = 0.002 # 向圆心的轻微漂移
|
| 260 |
+
pose[0, 3] = -radius_drift # 局部X轴负方向(向左)
|
| 261 |
+
|
| 262 |
+
relative_pose = pose[:3, :]
|
| 263 |
+
relative_poses.append(torch.as_tensor(relative_pose))
|
| 264 |
+
|
| 265 |
+
pose_embedding = torch.stack(relative_poses, dim=0)
|
| 266 |
+
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
|
| 267 |
+
|
| 268 |
+
# 创建对应长度的mask序列
|
| 269 |
+
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 270 |
+
condition_end = min(start_frame + current_history_length, max_needed_frames)
|
| 271 |
+
mask[start_frame:condition_end] = 1.0
|
| 272 |
+
|
| 273 |
+
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
|
| 274 |
+
print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
|
| 275 |
+
return camera_embedding.to(torch.bfloat16)
|
| 276 |
+
elif direction=="right":
|
| 277 |
+
print("------------Right----------")
|
| 278 |
+
|
| 279 |
+
max_needed_frames = max(
|
| 280 |
+
start_frame + current_history_length + new_frames,
|
| 281 |
+
framepack_needed_frames,
|
| 282 |
+
30
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
|
| 286 |
+
relative_poses = []
|
| 287 |
+
for i in range(max_needed_frames):
|
| 288 |
+
# 持续左转运动模式
|
| 289 |
+
yaw_per_frame = -0.00 # 每帧左转(正角度表示左转)
|
| 290 |
+
forward_speed = 0.1 # 每帧前进距离
|
| 291 |
+
|
| 292 |
+
pose = np.eye(4, dtype=np.float32)
|
| 293 |
+
|
| 294 |
+
# 旋转矩阵(绕Y轴左转)
|
| 295 |
+
cos_yaw = np.cos(yaw_per_frame)
|
| 296 |
+
sin_yaw = np.sin(yaw_per_frame)
|
| 297 |
+
|
| 298 |
+
pose[0, 0] = cos_yaw
|
| 299 |
+
pose[0, 2] = sin_yaw
|
| 300 |
+
pose[2, 0] = -sin_yaw
|
| 301 |
+
pose[2, 2] = cos_yaw
|
| 302 |
+
|
| 303 |
+
# 平移(在旋转后的局部坐标系中前进)
|
| 304 |
+
pose[2, 3] = -forward_speed # 局部Z轴负方向(前进)
|
| 305 |
+
|
| 306 |
+
# 添加轻微的向心运动,模拟圆形轨迹
|
| 307 |
+
radius_drift = 0.000 # 向圆心的轻微漂移
|
| 308 |
+
pose[0, 3] = radius_drift # 局部X轴负方向(向左)
|
| 309 |
+
|
| 310 |
+
relative_pose = pose[:3, :]
|
| 311 |
+
relative_poses.append(torch.as_tensor(relative_pose))
|
| 312 |
+
|
| 313 |
+
pose_embedding = torch.stack(relative_poses, dim=0)
|
| 314 |
+
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
|
| 315 |
+
|
| 316 |
+
# 创建对应长度的mask序列
|
| 317 |
+
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 318 |
+
condition_end = min(start_frame + current_history_length, max_needed_frames)
|
| 319 |
+
mask[start_frame:condition_end] = 1.0
|
| 320 |
+
|
| 321 |
+
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
|
| 322 |
+
print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
|
| 323 |
+
return camera_embedding.to(torch.bfloat16)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Generate camera embeddings for the OpenX dataset (sliding-window version).

    Returns a ``[max_needed_frames, 13]`` bfloat16 tensor: each row is a
    flattened 3x4 relative camera pose (12 values) plus a 1-value condition
    mask (1.0 = condition frame, 0.0 = target frame).

    Args:
        encoded_data: dict with ``encoded_data['cam_emb']['extrinsic']`` holding
            per-frame extrinsics, or None.  # assumes extrinsics are 4x4-like arrays — TODO confirm
        start_frame: index of the first condition frame in the mask.
        current_history_length: number of frames marked as condition.
        new_frames: number of frames to be generated this step.
        use_real_poses: if True and real extrinsics are present, derive
            relative poses from them; otherwise synthesize a gentle
            manipulation-style motion.
    """
    # OpenX latents are temporally compressed 4x, so poses are sampled
    # every 4 raw frames.
    time_compression_ratio = 4

    # Number of camera frames FramePack actually consumes:
    # 1 start + 16 (4x) + 2 (2x) + 1 (1x) + newly generated frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        # Ensure the camera sequence is long enough for every consumer.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX uses a 4-frame stride, like sekai, but with shorter clips.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        # Flatten each 3x4 pose into a 12-dim row vector (equivalent to the
        # einops rearrange 'b c d -> b (c d)', without the einops dependency).
        pose_embedding = pose_embedding.reshape(max_needed_frames, -1)

        # Condition mask: 1.0 on [start_frame, start_frame+history).
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用OpenX合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # Synthetic robot-manipulation motion: small per-frame
            # roll/pitch/yaw plus a slow forward translation.
            roll_per_frame = 0.02
            pitch_per_frame = 0.01
            yaw_per_frame = 0.015
            forward_speed = 0.003

            pose = np.eye(4, dtype=np.float32)

            cos_roll = np.cos(roll_per_frame)
            sin_roll = np.sin(roll_per_frame)
            cos_pitch = np.cos(pitch_per_frame)
            sin_pitch = np.sin(pitch_per_frame)
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            # Composite rotation matrix (ZYX order).
            pose[0, 0] = cos_yaw * cos_pitch
            pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
            pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
            pose[1, 0] = sin_yaw * cos_pitch
            pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
            pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
            pose[2, 0] = -sin_pitch
            pose[2, 1] = cos_pitch * sin_roll
            pose[2, 2] = cos_pitch * cos_roll

            # Translation: fine manipulation-style motion, dominated by depth.
            pose[0, 3] = forward_speed * 0.5
            pose[1, 3] = forward_speed * 0.3
            pose[2, 3] = -forward_speed

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = pose_embedding.reshape(max_needed_frames, -1)

        # Condition mask, same convention as the real-pose branch.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Generate camera embeddings for NuScenes (sliding-window version),
    kept consistent with train_moe.py.

    Returns a ``[max_needed_frames, 8]`` bfloat16 tensor: a 7-dim pose
    (tx, ty, tz, qw, qx, qy, qz) per frame plus a 1-value condition mask
    (1.0 = condition frame, 0.0 = target frame).

    Args:
        scene_info: dict with a 'keyframe_poses' list of
            ``{'translation': [3], 'rotation': [4]}`` entries, or None to
            synthesize a left-turn driving trajectory.
        start_frame: index of the first condition frame in the mask.
        current_history_length: number of frames marked as condition.
        new_frames: number of frames to be generated this step.
    """
    # Camera frames FramePack actually consumes:
    # 1 start + 16 (4x) + 2 (2x) + 1 (1x) + newly generated frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if len(keyframe_poses) == 0:
            # No keyframes recorded: emit an all-zero pose sequence so the
            # caller still gets a correctly shaped embedding.
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            max_needed_frames = max(framepack_needed_frames, 30)

            pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)

            mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
            condition_end = min(start_frame + current_history_length, max_needed_frames)
            mask[start_frame:condition_end] = 1.0

            camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
            print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)

        # Use the first keyframe as the translation reference.
        reference_pose = keyframe_poses[0]

        max_needed_frames = max(framepack_needed_frames, 30)

        pose_vecs = []
        for i in range(max_needed_frames):
            if i < len(keyframe_poses):
                current_pose = keyframe_poses[i]

                # Translation relative to the reference keyframe.
                translation = torch.tensor(
                    np.array(current_pose['translation']) - np.array(reference_pose['translation']),
                    dtype=torch.float32
                )

                # Rotation is taken as-is (simplified: absolute, not relative).
                rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)

                pose_vec = torch.cat([translation, rotation], dim=0)  # [7]
            else:
                # Past the recorded trajectory: zero translation + identity quaternion.
                pose_vec = torch.cat([
                    torch.zeros(3, dtype=torch.float32),
                    torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
                ], dim=0)  # [7]

            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)  # [max_needed_frames, 7]

        # Condition mask: 1.0 on [start_frame, start_frame+history).
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用NuScenes合成pose数据")
        max_needed_frames = max(framepack_needed_frames, 30)

        # Synthesize a left-turn trajectory, like a city-driving left turn.
        pose_vecs = []
        for i in range(max_needed_frames):
            angle = i * 0.04   # 0.04 rad of turn per frame (gentle curve)
            radius = 15.0      # large turning radius, car-like

            # Position on the circular arc.
            x = radius * np.sin(angle)
            y = 0.0            # stay in the horizontal plane
            z = radius * (1 - np.cos(angle))

            translation = torch.tensor([x, y, z], dtype=torch.float32)

            # Heading always tangent to the arc.
            yaw = angle + np.pi / 2  # yaw relative to the initial heading
            # Quaternion for a rotation about the Y axis.
            rotation = torch.tensor([
                np.cos(yaw / 2),  # w (real part)
                0.0,              # x
                0.0,              # y
                np.sin(yaw / 2)   # z (imaginary part, about Y)
            ], dtype=torch.float32)

            pose_vec = torch.cat([translation, rotation], dim=0)  # [7]: tx,ty,tz,qw,qx,qy,qz
            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)

        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes合成左转pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
|
| 548 |
+
|
| 549 |
+
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """FramePack sliding-window preparation - MoE version.

    Builds the clean-latent pyramid (4x/2x/1x + start frame), the index
    tensors, and the camera-embedding slice with a freshly computed
    condition mask for one generation step.

    Args:
        history_latents: [C, T, H, W] current history latents.
        target_frames_to_generate: number of latent frames to denoise.
        camera_embedding_full: [F, D] full camera sequence; zero-padded if
            shorter than the index window.
        start_frame: kept for interface compatibility (unused here).
        modality_type: passed through in the result for the MoE router.
        max_history_frames: kept for interface compatibility (unused here;
            the caller maintains the sliding window).

    Returns:
        dict with latents, index tensors, the per-step camera embedding,
        the modality type, and the current/next history lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout (this also determines how many camera frames are
    # needed): 1 start + 16 (4x) + 2 (2x) + 1 (1x) + target frames.
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    indices = torch.arange(0, total_indices_length)
    split_sizes = [1, 16, 2, 1, target_frames_to_generate]
    clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \
        indices.split(split_sizes, dim=0)
    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence if it is shorter than the index window.
    if camera_embedding_full.shape[0] < total_indices_length:
        shortage = total_indices_length - camera_embedding_full.shape[0]
        padding = torch.zeros(shortage, camera_embedding_full.shape[1],
                              dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)

    # Slice the window this step actually uses.
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()

    # Rebuild the mask from the current history length: first clear all
    # frames to target (0), then flag the valid clean-latent slots.
    combined_camera[:, -1] = 0.0

    # The 19 clean slots are filled from the *end* of history, so the first
    # valid slot is 19 - available_frames. Computed once, reused below.
    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        combined_camera[start_pos:19, -1] = 1.0  # mark valid clean latents as condition

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Assemble the 19-slot clean-latent buffer, right-aligned with the most
    # recent history frames (earlier slots stay zero when history is short).
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # The start slot always carries the very first history frame (anchor).
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag for the MoE router
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai", "nuscenes", or "openx"
    use_real_poses=True,
    scene_info_path=None,   # for the NuScenes modality
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None,
    direction="left",
    use_gt_prompt=True
):
    """MoE FramePack sliding-window video generation with multi-modality
    support (sekai / nuscenes / openx).

    Loads the Wan2.1 pipeline, attaches camera/FramePack/MoE components,
    loads the fine-tuned DiT weights, then autoregressively generates
    ``total_frames_to_generate`` latent frames in chunks of
    ``frames_per_generation`` with optional camera and text CFG, and finally
    decodes and writes the video to ``output_path``.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialization.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Attach the legacy camera encoder (compatibility with old ckpts):
    #    zero-initialized encoder + identity projector.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Attach FramePack components.
    add_framepack_components(pipe.dit)

    # 4. Attach MoE components.
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,    # Sekai: 12-dim pose + 1-dim mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-dim pose + 1-dim mask
        "openx_input_dim": 13     # OpenX: 12-dim pose + 1-dim mask (like sekai)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load the fine-tuned weights (strict=False tolerates the newly
    #    attached MoE modules).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop spatially to the model's latent resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start + target_height, w_start:w_start + target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)
    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Prompt encoding, with optional text CFG.
    if use_gt_prompt and 'prompt_emb' in encoded_data:
        print("✅ 使用预编码的GT prompt embedding")
        prompt_emb_pos = encoded_data['prompt_emb']
        # Move the embedding to the right device/dtype.
        if 'context' in prompt_emb_pos:
            prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype)
        if 'context_mask' in prompt_emb_pos:
            prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype)

        if text_guidance_scale > 1.0:
            prompt_emb_neg = pipe.encode_prompt("")
            print(f"使用Text CFG with GT prompt,guidance scale: {text_guidance_scale}")
        else:
            prompt_emb_neg = None
            print("不使用Text CFG")

        # Print the GT prompt text when available.
        if 'prompt' in encoded_data['prompt_emb']:
            gt_prompt_text = encoded_data['prompt_emb']['prompt']
            print(f"📝 GT Prompt文本: {gt_prompt_text}")
    else:
        # Re-encode from the supplied prompt string.
        print(f"🔄 重新编码prompt: {prompt}")
        if text_guidance_scale > 1.0:
            prompt_emb_pos = pipe.encode_prompt(prompt)
            prompt_emb_neg = pipe.encode_prompt("")
            print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
        else:
            prompt_emb_pos = pipe.encode_prompt(prompt)
            prompt_emb_neg = None
            print("不使用Text CFG")

    # 8. Scene info (NuScenes only).
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera-embedding sequence for the modality.
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            0,
            max_history_frames,
            0,
            0,
            use_real_poses=use_real_poses,
            direction=direction
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            0,
            max_history_frames,
            0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            0,
            max_history_frames,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. Unconditional camera embedding for camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation - MoE version.
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Batch the inputs.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        modality_inputs = {modality_type: camera_embedding}

        # Unconditional camera embedding batch for CFG.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors (kept on CPU as the DiT expects).
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Initialize the latents to be denoised.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        def _run_dit(latents, t, cam, modal, prompt_kwargs):
            # One DiT forward pass with the shared FramePack inputs;
            # the MoE auxiliary loss is discarded at inference time.
            pred, _moe_loss = pipe.dit(
                latents,
                timestep=t,
                cam_emb=cam,
                modality_inputs=modal,
                latent_indices=latent_indices,
                clean_latents=clean_latents,
                clean_latent_indices=clean_latent_indices,
                clean_latents_2x=clean_latents_2x,
                clean_latent_2x_indices=clean_latent_2x_indices,
                clean_latents_4x=clean_latents_4x,
                clean_latent_4x_indices=clean_latent_4x_indices,
                **prompt_kwargs,
                **extra_input
            )
            return pred

        # Denoising loop with optional CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional (with camera) / unconditional (zero camera).
                    noise_pred_cond = _run_dit(
                        new_latents, timestep_tensor,
                        camera_embedding, modality_inputs, prompt_emb_pos)
                    noise_pred_uncond = _run_dit(
                        new_latents, timestep_tensor,
                        camera_embedding_uncond_batch, modality_inputs_uncond,
                        prompt_emb_neg if prompt_emb_neg else prompt_emb_pos)

                    # Camera CFG.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack text CFG on top of the camera-CFG result.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond = _run_dit(
                            new_latents, timestep_tensor,
                            camera_embedding, modality_inputs, prompt_emb_neg)
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only.
                    noise_pred_cond = _run_dit(
                        new_latents, timestep_tensor,
                        camera_embedding, modality_inputs, prompt_emb_pos)
                    noise_pred_uncond = _run_dit(
                        new_latents, timestep_tensor,
                        camera_embedding, modality_inputs, prompt_emb_neg)
                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Plain inference (no CFG).
                    noise_pred = _run_dit(
                        new_latents, timestep_tensor,
                        camera_embedding, modality_inputs, prompt_emb_pos)

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the freshly generated chunk to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep the anchor frame + most recent frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames - 1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode and save.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] -> [0, 255] uint8, frame-major layout for the writer.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
|
| 1017 |
+
|
| 1018 |
+
|
| 1019 |
+
def main():
    """CLI entry point: parse arguments and run MoE FramePack sliding-window inference.

    Builds the argument parser, echoes the effective configuration, and
    forwards everything to ``inference_moe_framepack_sliding_window``.
    """

    def _str2bool(v):
        # FIX: the original used `default=False` with no type/action for the
        # boolean flags below, so ANY supplied value — including the literal
        # string "False" — was truthy. This parser keeps the `--flag VALUE`
        # CLI shape backward-compatible while interpreting values correctly.
        if isinstance(v, bool):
            return v
        return str(v).lower() in ("1", "true", "yes", "y")

    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    # --- basic I/O and sliding-window parameters -------------------------
    parser.add_argument("--condition_pth", type=str,
                        #default="/share_zhuyixuan05/zhuyixuan05/sekai-game-drone/00500210001_0012150_0012450/encoded_video.pth")
                        default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
    #default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
    #default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=24)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", type=_str2bool, default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A car is driving")
    parser.add_argument("--device", type=str, default="cuda")

    # --- modality selection ----------------------------------------------
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="nuscenes",
                        help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # --- classifier-free-guidance parameters -----------------------------
    parser.add_argument("--use_camera_cfg", type=_str2bool, default=False,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # --- MoE parameters ---------------------------------------------------
    parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
    parser.add_argument("--direction", type=str, default="left")
    parser.add_argument("--use_gt_prompt", action="store_true", default=False,
                        help="使用数据集中的ground truth prompt embedding")

    args = parser.parse_args()

    # Echo the effective configuration before any heavy model loading.
    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"使用GT Prompt: {args.use_gt_prompt}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
    # FIX: original printed "DiT<path>" with no separator between label and value.
    print(f"DiT checkpoint: {args.dit_path}")

    # Warn when NuScenes modality is selected without real scene info:
    # the pipeline then falls back to synthetic pose data.
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim,
        direction=args.direction,
        use_gt_prompt=args.use_gt_prompt
    )
|
| 1105 |
+
|
| 1106 |
+
|
| 1107 |
+
# Script entry point: parse CLI arguments and run sliding-window inference.
if __name__ == "__main__":
    main()
|
scripts/infer_recam.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from diffsynth import ModelManager, WanVideoReCamMasterPipeline, save_video, VideoData
|
| 5 |
+
import torch, os, imageio, argparse
|
| 6 |
+
from torchvision.transforms import v2
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torchvision
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import numpy as np
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
class Camera(object):
    """A camera pose wrapper exposing both the camera-to-world and the
    world-to-camera 4x4 extrinsic matrices of a single camera."""

    def __init__(self, c2w):
        # Accept any array-like with 16 entries and view it as a 4x4 matrix.
        mat = np.array(c2w).reshape(4, 4)
        self.c2w_mat = mat
        # World-to-camera is the matrix inverse of camera-to-world.
        self.w2c_mat = np.linalg.inv(mat)
|
| 19 |
+
|
| 20 |
+
class TextVideoCameraDataset(torch.utils.data.Dataset):
    """Dataset pairing (caption, video clip, target camera trajectory).

    Videos are listed in a CSV with `file_name` and `text` columns; target
    camera extrinsics are read from a fixed JSON file and converted into
    per-frame relative-pose embeddings of shape [target_frames, 12].
    """

    def __init__(self, base_path, metadata_path, args, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, condition_frames=40, target_frames=20):
        # Resolve video paths and captions from the metadata CSV.
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "videos", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.is_i2v = is_i2v
        self.args = args
        # Camera trajectory id selecting the "camNN" entry in the extrinsics JSON.
        self.cam_type = self.args.cam_type

        # Frame-count configuration (condition/target split of the window).
        self.condition_frames = condition_frames
        self.target_frames = target_frames

        # Per-frame preprocessing: crop/resize to (height, width), then
        # normalize pixels to [-1, 1].
        self.frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Upscale a PIL image so both sides cover the target size (aspect kept)."""
        width, height = image.size
        # Scale so the SMALLER relative dimension reaches the target; the
        # later CenterCrop in `frame_process` trims the overshoot.
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Read `num_frames` frames (stride `interval`) starting at `start_frame_id`.

        Returns a [C, T, H, W] tensor, or (frames, first_frame) when `is_i2v`,
        or None when the video is too short for the requested window.
        """
        reader = imageio.get_reader(file_path)
        # Reject clips that are shorter than required for this window.
        if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None

        frames = []
        first_frame = None
        for frame_id in range(num_frames):
            frame = reader.get_data(start_frame_id + frame_id * interval)
            frame = Image.fromarray(frame)
            frame = self.crop_and_resize(frame)
            # Keep an un-normalized copy of the first frame for i2v use.
            if first_frame is None:
                first_frame = np.array(frame)
            frame = frame_process(frame)
            frames.append(frame)
        reader.close()

        frames = torch.stack(frames, dim=0)
        # Convert from time-major to channel-major layout.
        frames = rearrange(frames, "T C H W -> C T H W")

        if self.is_i2v:
            return frames, first_frame
        else:
            return frames

    def is_image(self, file_path):
        """Return True when the path has a common still-image extension."""
        file_ext_name = file_path.split(".")[-1]
        if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
            return True
        return False

    def load_video(self, file_path):
        """Load a randomly positioned window of frames from the video file."""
        # Random start so every valid window of num_frames fits in the clip.
        start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
        frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
        return frames

    def parse_matrix(self, matrix_str):
        """Parse a matrix serialized as '[a b ...] [c d ...] ...' into an ndarray."""
        rows = matrix_str.strip().split('] [')
        matrix = []
        for row in rows:
            row = row.replace('[', '').replace(']', '')
            matrix.append(list(map(float, row.split())))
        return np.array(matrix)

    def get_relative_pose(self, cam_params):
        """Express a list of Camera poses relative to the first camera.

        Returns a float32 array of c2w matrices where the first pose is the
        canonical target frame and the rest are re-based onto it.
        """
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]

        # Canonical first-camera frame (cam_to_origin = 0 -> plain identity).
        cam_to_origin = 0
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, -cam_to_origin],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        # Transform mapping absolute world coords into the first camera's frame.
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        ret_poses = np.array(ret_poses, dtype=np.float32)
        return ret_poses

    def __getitem__(self, data_id):
        """Return {"text", "video", "path", "camera"} for one sample.

        Raises:
            ValueError: when the video is too short to yield a valid window.
        """
        text = self.text[data_id]
        path = self.path[data_id]
        video = self.load_video(path)
        if video is None:
            raise ValueError(f"{path} is not a valid video.")
        num_frames = video.shape[1]
        # NOTE(review): hard requirement of exactly 81 frames — `assert` is
        # stripped under `python -O`; consider raising instead.
        assert num_frames == 81
        data = {"text": text, "video": video, "path": path}

        # Load the target camera trajectory from a fixed JSON file.
        tgt_camera_path = "./example_test_data/cameras/camera_extrinsics.json"
        with open(tgt_camera_path, 'r') as file:
            cam_data = json.load(file)

        # Sample target_frames indices evenly over the 81-frame trajectory.
        cam_idx = np.linspace(0, 80, self.target_frames, dtype=int).tolist()
        traj = [self.parse_matrix(cam_data[f"frame{idx}"][f"cam{int(self.cam_type):02d}"]) for idx in cam_idx]
        # Transpose each matrix (stored row-major per-column in the JSON).
        traj = np.stack(traj).transpose(0, 2, 1)
        c2ws = []
        for c2w in traj:
            # Axis permutation + handedness flip + translation unit rescale
            # (divide by 100 — presumably cm -> m; confirm against data source).
            c2w = c2w[:, [1, 2, 0, 3]]
            c2w[:3, 1] *= -1.
            c2w[:3, 3] /= 100
            c2ws.append(c2w)
        tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
        relative_poses = []
        for i in range(len(tgt_cam_params)):
            # Pose of frame i expressed relative to frame 0; keep the 3x4 part.
            relative_pose = self.get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]])
            relative_poses.append(torch.as_tensor(relative_pose)[:,:3,:][1])
        pose_embedding = torch.stack(relative_poses, dim=0)  # [target_frames, 3, 4]
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # [target_frames, 12]
        data['camera'] = pose_embedding.to(torch.bfloat16)
        return data

    def __len__(self):
        """Number of videos listed in the metadata CSV."""
        return len(self.path)
|
| 154 |
+
|
| 155 |
+
def parse_args():
    """Parse command-line arguments for ReCamMaster inference.

    Returns:
        argparse.Namespace carrying dataset/checkpoint/output paths, the
        camera trajectory type, CFG scale and the condition/target frame split.
    """
    parser = argparse.ArgumentParser(description="ReCamMaster Inference")
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="./example_test_data",
        help="The path of the Dataset.",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/recam_future_checkpoint/step1000.ckpt",
        help="Path to save the model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./results",
        help="Path to save the results.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    # Camera trajectory index formatted into the "camNN" key of the
    # extrinsics JSON. FIX: argparse does NOT pass defaults through `type=`,
    # so the original `default=1` (int) silently violated the declared str
    # contract; "1" yields identical downstream behavior (int("1") == 1 and
    # the same f-string output).
    parser.add_argument(
        "--cam_type",
        type=str,
        default="1",
    )
    # Classifier-free-guidance scale used by the pipeline.
    parser.add_argument(
        "--cfg_scale",
        type=float,
        default=5.0,
    )
    # Condition/target split of the generation window.
    parser.add_argument(
        "--condition_frames",
        type=int,
        default=15,
        help="Number of condition frames",
    )
    parser.add_argument(
        "--target_frames",
        type=int,
        default=15,
        help="Number of target frames to generate",
    )
    args = parser.parse_args()
    return args
|
| 206 |
+
|
| 207 |
+
# Script body: load Wan2.1, graft ReCamMaster camera modules onto the DiT,
# restore the fine-tuned checkpoint, and run re-camera inference over the
# test dataset, saving one MP4 per sample.
if __name__ == '__main__':
    args = parse_args()

    # 1. Load Wan2.1 pre-trained models (kept on CPU until the pipeline moves them).
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Initialize the additional modules introduced by ReCamMaster:
    # a per-block camera encoder (12-D flattened 3x4 pose -> hidden dim) and a
    # projector initialized to identity so the untrained modules are a no-op.
    dim=pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(12, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Load the ReCamMaster fine-tuned checkpoint (strict: every key must match).
    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    pipe.dit.load_state_dict(state_dict, strict=True)
    pipe.to("cuda")
    pipe.to(dtype=torch.bfloat16)

    # Results are grouped per camera trajectory type.
    output_dir = os.path.join(args.output_dir, f"cam_type{args.cam_type}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # 4. Prepare test data (source video, target camera, target trajectory).
    dataset = TextVideoCameraDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        args,
        condition_frames=args.condition_frames,  # condition/target split
        target_frames=args.target_frames,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    # 5. Inference: one generated video per dataset sample.
    for batch_idx, batch in enumerate(dataloader):
        target_text = batch["text"]
        source_video = batch["video"]
        target_camera = batch["camera"]

        # NOTE(review): the negative prompt below contains mojibake
        # ("杂乱的��景", likely intended "杂乱的背景") — left byte-identical
        # here; fix at the source if the original text can be recovered.
        video = pipe(
            prompt=target_text,
            negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的��景,三条腿,背景人很多,倒着走",
            source_video=source_video,
            target_camera=target_camera,
            cfg_scale=args.cfg_scale,
            num_inference_steps=50,
            seed=0,
            tiled=True,
            condition_frames=args.condition_frames,
            target_frames=args.target_frames,
        )
        save_video(video, os.path.join(output_dir, f"video{batch_idx}.mp4"), fps=30, quality=5)
|
scripts/infer_rlbench.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import imageio
|
| 6 |
+
import json
|
| 7 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 8 |
+
import argparse
|
| 9 |
+
from torchvision.transforms import v2
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """
    Load pre-encoded (VAE latent) video data from a .pth file.

    Args:
        pth_path: path to the .pth file produced by the encoding scripts
        start_frame: starting index, in compressed latent frames
        num_frames: number of latent frames to extract
    Returns:
        (condition_latents, encoded_data): the sliced [C, T, H, W] latent
        tensor and the full dict loaded from disk (may carry e.g. 'cam_emb').
    Raises:
        ValueError: when the stored sequence is shorter than the requested window.
    """
    print(f"Loading encoded video from {pth_path}")

    # weights_only=False because the file stores a plain dict with metadata.
    # NOTE(review): this unpickles arbitrary objects — only load trusted files.
    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")

    # Latents are stored channel-first: [C, T, H, W].
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    # Make sure the requested window fits inside the stored sequence.
    if start_frame + num_frames > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    # Slice the requested temporal window along the frame axis.
    condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :]

    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return camera B's pose relative to camera A, i.e. ``B @ inv(A)``.

    Both inputs must be 4x4 extrinsic matrices (numpy arrays or torch
    tensors). ``use_torch`` selects the torch code path; numpy inputs are
    promoted to float32 tensors there, and non-array inputs are coerced to
    float32 ndarrays on the numpy path.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float32 tensors before inverting.
        pose_a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        pose_b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # numpy path: coerce anything that is not already an ndarray.
    pose_a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    pose_b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def generate_camera_poses_from_data(cam_data, start_frame, condition_frames, target_frames):
    """
    Build a camera embedding from recorded camera extrinsics.

    Args:
        cam_data: camera extrinsics sequence, expected [N, 4, 4] — TODO confirm
        start_frame: starting latent-frame index (pre-compression units derived below)
        condition_frames: number of condition frames (compressed)
        target_frames: number of target frames (compressed)
    Returns:
        bfloat16 tensor concatenating per-frame pose data with a condition mask.
    """
    # One latent frame corresponds to 4 original video frames.
    time_compression_ratio = 4
    total_frames = condition_frames + target_frames

    # Camera extrinsics sequence, indexed in ORIGINAL frame units.
    cam_extrinsic = cam_data  # [N, 4, 4]

    # Map compressed-frame window back to original frame indices.
    start_frame_original = start_frame * time_compression_ratio
    end_frame_original = (start_frame + total_frames) * time_compression_ratio

    print(f"Using camera data from frame {start_frame_original} to {end_frame_original}")

    # Collect one pose per latent frame (stride = compression ratio).
    relative_poses = []
    for i in range(total_frames):
        frame_idx = start_frame_original + i * time_compression_ratio
        # NOTE(review): next_frame_idx is computed but never used — the name
        # "relative_poses" suggests a frame-to-frame delta was intended here,
        # yet the ABSOLUTE pose is appended instead. Confirm intent.
        next_frame_idx = frame_idx + time_compression_ratio

        cam_prev = cam_extrinsic[frame_idx]

        relative_poses.append(torch.as_tensor(cam_prev))  # original note said "take top 3 rows", but the full pose is kept

        print(cam_prev)  # debug output — consider removing for production runs
    # Assemble the per-frame pose embedding.
    pose_embedding = torch.stack(relative_poses, dim=0)
    # print('pose_embedding init:',pose_embedding[0])
    print('pose_embedding:',pose_embedding)
    # assert False

    # pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') # [frames, 12]

    # Condition mask: 1 for condition frames, 0 for frames to generate.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0  # condition frames
    mask = mask.view(-1, 1)

    # Combine pose and mask per frame.
    # NOTE(review): with the flattening rearrange above commented out,
    # pose_embedding appears 3-D ([T, 4, 4]) while mask is [T, 1]; this cat
    # on dim=1 would then fail at runtime — verify against the actual
    # shape of cam_data entries.
    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]

    print(f"Generated camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def generate_camera_poses(direction="forward", target_frames=10, condition_frames=20):
    """
    Generate a synthetic camera pose embedding for a scripted motion direction.

    Args:
        direction: one of "forward", "backward", "left_turn", "right_turn"
        target_frames: frames to generate (compressed latent frames)
        condition_frames: conditioning frames (compressed latent frames)
    Returns:
        bfloat16 tensor of shape [total_frames, 13]: 12-D flattened 3x4
        frame-to-frame relative pose + 1-D condition mask per frame.
    """
    # NOTE(review): unused here — presumably kept for parity with the
    # real-pose path in generate_camera_poses_from_data.
    time_compression_ratio = 4
    total_frames = condition_frames + target_frames

    poses = []

    for i in range(total_frames):
        t = i / max(1, total_frames - 1)  # normalized progress in [0, 1]

        # Start from an identity extrinsic and perturb it per direction.
        pose = np.eye(4, dtype=np.float32)

        if direction == "forward":
            # Forward: translate along -z.
            pose[2, 3] = -t * 0.04
            print('forward!')  # debug output

        elif direction == "backward":
            # Backward: translate along +z.
            # NOTE(review): step size (2.0) is far larger than the other
            # directions (~0.03) — confirm this asymmetry is intentional.
            pose[2, 3] = t * 2.0

        elif direction == "left_turn":
            # Left turn: forward motion + left strafe + yaw about y.
            pose[2, 3] = -t * 0.03  # forward
            pose[0, 3] = t * 0.02  # strafe left
            # Rotation block (yaw about the y axis).
            yaw = t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)

        elif direction == "right_turn":
            # Right turn: forward motion + right strafe + opposite yaw.
            pose[2, 3] = -t * 0.03  # forward
            pose[0, 3] = -t * 0.02  # strafe right
            # Rotation block (negative yaw about the y axis).
            yaw = - t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)

        poses.append(pose)

    # Convert absolute poses into consecutive frame-to-frame relative poses.
    relative_poses = []
    for i in range(len(poses) - 1):
        relative_pose = compute_relative_pose(poses[i], poses[i + 1])
        relative_poses.append(torch.as_tensor(relative_pose[:3, :]))  # keep the top 3x4 block

    # N poses yield N-1 deltas; pad with the last delta to reach total_frames.
    if len(relative_poses) < total_frames:
        # duplicate the final relative pose
        relative_poses.append(relative_poses[-1])

    pose_embedding = torch.stack(relative_poses[:total_frames], dim=0)

    print('pose_embedding init:',pose_embedding[0])

    print('pose_embedding:',pose_embedding[-5:])

    # Flatten each 3x4 pose into a 12-vector.
    pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # [frames, 12]

    # Condition mask: 1 for condition frames, 0 for frames to generate.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0  # condition frames
    mask = mask.view(-1, 1)

    # Combine pose and mask per frame.
    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]

    print(f"Generated {direction} movement poses:")
    print(f"  Total frames: {total_frames}")
    print(f"  Camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def inference_sekai_video_from_pth(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai.mp4",
    start_frame=0,
    condition_frames=10,  # compressed latent frames
    target_frames=2,  # compressed latent frames
    device="cuda",
    prompt="a robotic arm executing precise manipulation tasks on a clean, organized desk",
    direction="forward",
    use_real_poses=True
):
    """
    Run Sekai/RLBench video continuation from a pre-encoded latent .pth file.

    Loads Wan2.1 + camera-conditioned DiT, conditions on `condition_frames`
    latent frames, denoises `target_frames` new frames guided by a camera
    embedding (real poses from the file or synthetic `direction` poses),
    then decodes and writes an MP4 to `output_path`.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print(f"Setting up models for {direction} movement...")

    # 1. Load pre-trained Wan2.1 models on CPU, then build the pipeline.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add camera-conditioning components to every DiT block, zero/identity
    # initialized so they are a no-op before the checkpoint is loaded.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        # NOTE(review): the original comment claimed a 13-dim embedding
        # (12-D pose + 1-D mask), but this layer takes 30 input features —
        # confirm which camera embedding layout the checkpoint expects.
        block.cam_encoder = nn.Linear(30, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Load trained DiT weights (strict: all keys must match).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    # 50 denoising steps for the scheduler.
    pipe.scheduler.set_timesteps(50)

    print("Loading condition video from pth...")

    # Load the conditioning latent window from disk.
    condition_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=condition_frames
    )

    # Add a batch dim: [C, T, H, W] -> [1, C, T, H, W].
    condition_latents = condition_latents.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

    print("Processing poses...")

    # Camera embedding: prefer real poses stored in the file, otherwise
    # fall back to a synthetic trajectory in the requested direction.
    if use_real_poses and 'cam_emb' in encoded_data:
        print("Using real camera poses from data")
        camera_embedding = generate_camera_poses_from_data(
            encoded_data['cam_emb'],
            start_frame=start_frame,
            condition_frames=condition_frames,
            target_frames=target_frames
        )
    else:
        print(f"Using synthetic {direction} poses")
        camera_embedding = generate_camera_poses(
            direction=direction,
            target_frames=target_frames,
            condition_frames=condition_frames
        )

    # Add a batch dim to the camera embedding as well.
    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16)

    print(f"Camera embedding shape: {camera_embedding.shape}")

    print("Encoding prompt...")

    # Encode text prompt into the conditioning dict the DiT consumes.
    prompt_emb = pipe.encode_prompt(prompt)

    print("Generating video...")

    # Shapes for the noise initialization of the frames to generate.
    batch_size = 1
    channels = condition_latents.shape[1]
    latent_height = condition_latents.shape[3]
    latent_width = condition_latents.shape[4]

    # Optional spatial center crop to bound memory usage.
    target_height, target_width = 64, 64

    if latent_height > target_height or latent_width > target_width:
        # Center crop the condition latents to the target spatial size.
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2
        condition_latents = condition_latents[:, :, :,
                            h_start:h_start+target_height,
                            w_start:w_start+target_width]
        latent_height = target_height
        latent_width = target_width

    # Initialize the target frames with pure Gaussian noise.
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=pipe.torch_dtype
    )

    print(f"Condition latents shape: {condition_latents.shape}")
    print(f"Target latents shape: {target_latents.shape}")
    print(f"Camera embedding shape: {camera_embedding.shape}")

    # Concatenate condition + target along the time axis for the DiT.
    combined_latents = torch.cat([condition_latents, target_latents], dim=2)
    print(f"Combined latents shape: {combined_latents.shape}")

    # Extra model inputs derived from the combined latent shape.
    extra_input = pipe.prepare_extra_input(combined_latents)

    # Denoising loop: condition frames stay fixed; only the target tail is
    # updated each step.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        # Scheduler timesteps are scalars; the DiT expects a batch dim.
        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

        # Predict noise over the full (condition + target) sequence.
        with torch.no_grad():
            noise_pred = pipe.dit(
                combined_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                **prompt_emb,
                **extra_input
            )

        # Step the scheduler on the target portion only.
        target_noise_pred = noise_pred[:, :, condition_frames:, :, :]
        target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents)

        # Write the refreshed target frames back into the combined tensor.
        combined_latents[:, :, condition_frames:, :, :] = target_latents

    print("Decoding video...")

    # Decode condition + generated latents back to pixel space (tiled VAE).
    final_video = torch.cat([condition_latents, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # Save video
    print(f"Saving video to {output_path}")

    # [C, T, H, W] -> [T, H, W, C] uint8 frames.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)  # Denormalize from [-1, 1]
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"Video generation completed! Saved to {output_path}")
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def main():
|
| 383 |
+
parser = argparse.ArgumentParser(description="Sekai Video Generation Inference from PTH")
|
| 384 |
+
parser.add_argument("--condition_pth", type=str,
|
| 385 |
+
default="/share_zhuyixuan05/zhuyixuan05/rlbench/OpenBox_demo_49/encoded_video.pth")
|
| 386 |
+
parser.add_argument("--start_frame", type=int, default=0,
|
| 387 |
+
help="Starting frame index (compressed latent frames)")
|
| 388 |
+
parser.add_argument("--condition_frames", type=int, default=8,
|
| 389 |
+
help="Number of condition frames (compressed latent frames)")
|
| 390 |
+
parser.add_argument("--target_frames", type=int, default=8,
|
| 391 |
+
help="Number of target frames to generate (compressed latent frames)")
|
| 392 |
+
parser.add_argument("--direction", type=str, default="left_turn",
|
| 393 |
+
choices=["forward", "backward", "left_turn", "right_turn"],
|
| 394 |
+
help="Direction of camera movement (if not using real poses)")
|
| 395 |
+
parser.add_argument("--use_real_poses", default=False,
|
| 396 |
+
help="Use real camera poses from data")
|
| 397 |
+
parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/RLBench-train/step2000_dynamic.ckpt",
|
| 398 |
+
help="Path to trained DiT checkpoint")
|
| 399 |
+
parser.add_argument("--output_path", type=str, default='/home/zhuyixuan05/ReCamMaster/rlbench/infer_results/output_rl_2.mp4',
|
| 400 |
+
help="Output video path")
|
| 401 |
+
parser.add_argument("--prompt", type=str,
|
| 402 |
+
default="a robotic arm executing precise manipulation tasks on a clean, organized desk",
|
| 403 |
+
help="Text prompt for generation")
|
| 404 |
+
parser.add_argument("--device", type=str, default="cuda",
|
| 405 |
+
help="Device to run inference on")
|
| 406 |
+
|
| 407 |
+
args = parser.parse_args()
|
| 408 |
+
|
| 409 |
+
# 生成输出路径
|
| 410 |
+
if args.output_path is None:
|
| 411 |
+
pth_filename = os.path.basename(args.condition_pth)
|
| 412 |
+
name_parts = os.path.splitext(pth_filename)
|
| 413 |
+
output_dir = "rlbench/infer_results"
|
| 414 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 415 |
+
|
| 416 |
+
if args.use_real_poses:
|
| 417 |
+
output_filename = f"{name_parts[0]}_real_poses_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"
|
| 418 |
+
else:
|
| 419 |
+
output_filename = f"{name_parts[0]}_{args.direction}_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"
|
| 420 |
+
|
| 421 |
+
output_path = os.path.join(output_dir, output_filename)
|
| 422 |
+
else:
|
| 423 |
+
output_path = args.output_path
|
| 424 |
+
|
| 425 |
+
print(f"Input pth: {args.condition_pth}")
|
| 426 |
+
print(f"Start frame: {args.start_frame} (compressed)")
|
| 427 |
+
print(f"Condition frames: {args.condition_frames} (compressed, original: {args.condition_frames * 4})")
|
| 428 |
+
print(f"Target frames: {args.target_frames} (compressed, original: {args.target_frames * 4})")
|
| 429 |
+
print(f"Use real poses: {args.use_real_poses}")
|
| 430 |
+
print(f"Output video will be saved to: {output_path}")
|
| 431 |
+
|
| 432 |
+
inference_sekai_video_from_pth(
|
| 433 |
+
condition_pth_path=args.condition_pth,
|
| 434 |
+
dit_path=args.dit_path,
|
| 435 |
+
output_path=output_path,
|
| 436 |
+
start_frame=args.start_frame,
|
| 437 |
+
condition_frames=args.condition_frames,
|
| 438 |
+
target_frames=args.target_frames,
|
| 439 |
+
device=args.device,
|
| 440 |
+
prompt=args.prompt,
|
| 441 |
+
direction=args.direction,
|
| 442 |
+
use_real_poses=args.use_real_poses
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
# Script entry guard: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|