EvanEternal committed on
Commit
08bf07d
·
verified ·
1 Parent(s): 9ecdc6d

Upload 86 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. icons/move_backward.png +0 -0
  2. icons/move_forward.png +0 -0
  3. icons/move_left.png +0 -0
  4. icons/move_right.png +0 -0
  5. icons/not_move_backward.png +0 -0
  6. icons/not_move_forward.png +0 -0
  7. icons/not_move_left.png +0 -0
  8. icons/not_move_right.png +0 -0
  9. icons/not_turn_down.png +0 -0
  10. icons/not_turn_left.png +0 -0
  11. icons/not_turn_right.png +0 -0
  12. icons/not_turn_up.png +0 -0
  13. icons/turn_down.png +0 -0
  14. icons/turn_left.png +0 -0
  15. icons/turn_right.png +0 -0
  16. icons/turn_up.png +0 -0
  17. models/Astra/checkpoints/Put ReCamMaster ckpt file here.txt +0 -0
  18. models/Astra/checkpoints/README.md +5 -0
  19. scripts/add_text_emb.py +161 -0
  20. scripts/add_text_emb_rl.py +161 -0
  21. scripts/add_text_emb_spatialvid.py +173 -0
  22. scripts/analyze_openx.py +243 -0
  23. scripts/analyze_pose.py +188 -0
  24. scripts/batch_drone.py +44 -0
  25. scripts/batch_infer.py +186 -0
  26. scripts/batch_nus.py +42 -0
  27. scripts/batch_rt.py +41 -0
  28. scripts/batch_spa.py +43 -0
  29. scripts/batch_walk.py +42 -0
  30. scripts/check.py +263 -0
  31. scripts/decode_openx.py +428 -0
  32. scripts/download_recam.py +7 -0
  33. scripts/download_wan2.1.py +5 -0
  34. scripts/encode_dynamic_videos.py +141 -0
  35. scripts/encode_openx.py +466 -0
  36. scripts/encode_rlbench_video.py +170 -0
  37. scripts/encode_sekai_video.py +162 -0
  38. scripts/encode_sekai_walking.py +249 -0
  39. scripts/encode_spatialvid.py +409 -0
  40. scripts/encode_spatialvid_first_frame.py +285 -0
  41. scripts/hud_logo.py +40 -0
  42. scripts/infer_demo.py +1458 -0
  43. scripts/infer_moe.py +1023 -0
  44. scripts/infer_moe_spatialvid.py +1008 -0
  45. scripts/infer_moe_test.py +976 -0
  46. scripts/infer_nus.py +500 -0
  47. scripts/infer_openx.py +614 -0
  48. scripts/infer_origin.py +1108 -0
  49. scripts/infer_recam.py +272 -0
  50. scripts/infer_rlbench.py +447 -0
icons/move_backward.png ADDED
icons/move_forward.png ADDED
icons/move_left.png ADDED
icons/move_right.png ADDED
icons/not_move_backward.png ADDED
icons/not_move_forward.png ADDED
icons/not_move_left.png ADDED
icons/not_move_right.png ADDED
icons/not_turn_down.png ADDED
icons/not_turn_left.png ADDED
icons/not_turn_right.png ADDED
icons/not_turn_up.png ADDED
icons/turn_down.png ADDED
icons/turn_left.png ADDED
icons/turn_right.png ADDED
icons/turn_up.png ADDED
models/Astra/checkpoints/Put ReCamMaster ckpt file here.txt ADDED
File without changes
models/Astra/checkpoints/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ # ReCamMaster: Camera-Controlled Generative Rendering from A Single Video
5
+ Please refer to the [Github](https://github.com/KwaiVGI/ReCamMaster) README for usage.
scripts/add_text_emb.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ import pdb
13
+ from tqdm import tqdm
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
class VideoEncoder(pl.LightningModule):
    """Thin wrapper bundling the WanVideo ReCamMaster pipeline with frame preprocessing.

    Loads the T5 text encoder and the Wan VAE on CPU in bfloat16 and exposes
    helpers to decode a video file into a normalized (C, T, H, W) tensor.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(manager)
        # Tiling options forwarded to the VAE when encoding videos.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # PIL frame -> float tensor in [0, 1] -> normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize *image* to the fixed 832x480 model resolution (bilinear)."""
        target_width, target_height = 832, 480
        return v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Decode *video_path* into a (C, T, H, W) tensor, or None for an empty video."""
        reader = imageio.get_reader(video_path)
        processed = [
            self.frame_process(self.crop_and_resize(Image.fromarray(raw)))
            for raw in reader
        ]
        reader.close()

        if not processed:
            return None

        stacked = torch.stack(processed, dim=0)  # (T, C, H, W)
        return rearrange(stacked, "T C H W -> C T H W")
62
+
63
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Add a shared text-prompt embedding to every encoded scene.

    For each entry under *scenes_path*, loads ``<output_dir>/<stem>/encoded_video.pth``
    and, when the ``prompt_emb`` key (or any required key) is missing, encodes a
    fixed pedestrian-walking prompt once and writes it back into the file in place.

    Args:
        scenes_path: Directory whose entries name the scenes to process.
        text_encoder_path: Path to the T5 text-encoder checkpoint.
        vae_path: Path to the Wan VAE checkpoint.
        output_dir: Root directory holding the per-scene encoded ``.pth`` files.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    prompt_emb = None  # encoded lazily, once, on the first file that needs it

    os.makedirs(output_dir, exist_ok=True)
    required_keys = ["latents", "cam_emb", "prompt_emb"]

    # List the directory once instead of twice (originally listed for tqdm's total too).
    scene_names = os.listdir(scenes_path)
    for scene_name in tqdm(scene_names, total=len(scene_names)):
        # Scene files like "foo.mp4" map to the encoded directory "foo".
        save_dir = os.path.join(output_dir, scene_name.split('.')[0])
        encoded_path = os.path.join(save_dir, "encoded_video.pth")
        print(f"Checking scene {scene_name}...")

        # Fix: the existence check was commented out, so un-encoded entries
        # crashed torch.load with FileNotFoundError. Skip them instead.
        if not os.path.exists(encoded_path):
            print(f"Missing encoded file, skipping: {encoded_path}")
            continue

        # map_location="cpu" keeps the cached latents off the GPU while editing
        # the dict (consistent with the spatialvid variant of this script).
        data = torch.load(encoded_path, weights_only=False, map_location="cpu")
        missing_keys = [key for key in required_keys if key not in data]

        if missing_keys:
            print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
        else:
            print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb")
            continue

        with torch.no_grad():
            if prompt_emb is None:
                # Encode the shared prompt once, then drop the prompter to free memory.
                print('encode prompt!!!')
                prompt_emb = encoder.pipe.encode_prompt(
                    "A video of a scene shot using a pedestrian's front camera while walking"
                )
                del encoder.pipe.prompter

            data["prompt_emb"] = prompt_emb
            print("已添加/更新 prompt_emb 元素")
            # Overwrites the original file in place.
            torch.save(data, encoded_path)

        print(f"Saved encoded data: {encoded_path}")
        processed_count += 1

    print(processed_count)
    print(f"Encoding completed! Processed {processed_count} scenes.")
148
+
149
if __name__ == "__main__":
    # Script entry point: all paths default to the sekai-game-walking layout.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--scenes_path", type=str,
                            default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking")
    arg_parser.add_argument("--text_encoder_path", type=str,
                            default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
    arg_parser.add_argument("--vae_path", type=str,
                            default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
    arg_parser.add_argument("--output_dir", type=str,
                            default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking")
    cli_args = arg_parser.parse_args()
    encode_scenes(cli_args.scenes_path, cli_args.text_encoder_path,
                  cli_args.vae_path, cli_args.output_dir)
scripts/add_text_emb_rl.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ import pdb
13
+ from tqdm import tqdm
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
class VideoEncoder(pl.LightningModule):
    """Thin wrapper bundling the WanVideo ReCamMaster pipeline with frame preprocessing.

    Loads the T5 text encoder and the Wan VAE on CPU in bfloat16 and exposes
    helpers to decode a video file into a normalized (C, T, H, W) tensor.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(manager)
        # Tiling options forwarded to the VAE when encoding videos.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # PIL frame -> float tensor in [0, 1] -> normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize *image* to the fixed 832x480 model resolution (bilinear)."""
        target_width, target_height = 832, 480
        return v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Decode *video_path* into a (C, T, H, W) tensor, or None for an empty video."""
        reader = imageio.get_reader(video_path)
        processed = [
            self.frame_process(self.crop_and_resize(Image.fromarray(raw)))
            for raw in reader
        ]
        reader.close()

        if not processed:
            return None

        stacked = torch.stack(processed, dim=0)  # (T, C, H, W)
        return rearrange(stacked, "T C H W -> C T H W")
62
+
63
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Add a shared text-prompt embedding to every encoded RLBench scene.

    For each entry under *scenes_path*, loads ``<output_dir>/<stem>/encoded_video.pth``
    and, when the ``prompt_emb`` key (or any required key) is missing, encodes a
    fixed robotic-arm prompt once and writes it back into the file in place.

    Args:
        scenes_path: Directory whose entries name the scenes to process.
        text_encoder_path: Path to the T5 text-encoder checkpoint.
        vae_path: Path to the Wan VAE checkpoint.
        output_dir: Root directory holding the per-scene encoded ``.pth`` files.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    prompt_emb = None  # encoded lazily, once, on the first file that needs it

    os.makedirs(output_dir, exist_ok=True)
    required_keys = ["latents", "cam_emb", "prompt_emb"]

    # List the directory once instead of twice (originally listed for tqdm's total too).
    scene_names = os.listdir(scenes_path)
    for scene_name in tqdm(scene_names, total=len(scene_names)):
        # Scene files like "foo.mp4" map to the encoded directory "foo".
        save_dir = os.path.join(output_dir, scene_name.split('.')[0])
        encoded_path = os.path.join(save_dir, "encoded_video.pth")
        print(f"Checking scene {scene_name}...")

        # Fix: the existence check was commented out, so un-encoded entries
        # crashed torch.load with FileNotFoundError. Skip them instead.
        if not os.path.exists(encoded_path):
            print(f"Missing encoded file, skipping: {encoded_path}")
            continue

        # map_location="cpu" keeps the cached latents off the GPU while editing
        # the dict (consistent with the spatialvid variant of this script).
        data = torch.load(encoded_path, weights_only=False, map_location="cpu")
        missing_keys = [key for key in required_keys if key not in data]

        if missing_keys:
            print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
        else:
            print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb")
            continue

        with torch.no_grad():
            if prompt_emb is None:
                # Encode the shared prompt once, then drop the prompter to free memory.
                print('encode prompt!!!')
                prompt_emb = encoder.pipe.encode_prompt(
                    "a robotic arm executing precise manipulation tasks on a clean, organized desk"
                )
                del encoder.pipe.prompter

            data["prompt_emb"] = prompt_emb
            print("已添加/更新 prompt_emb 元素")
            # Overwrites the original file in place.
            torch.save(data, encoded_path)

        print(f"Saved encoded data: {encoded_path}")
        processed_count += 1

    print(processed_count)
    print(f"Encoding completed! Processed {processed_count} scenes.")
148
+
149
if __name__ == "__main__":
    # Script entry point: all paths default to the rlbench layout.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--scenes_path", type=str,
                            default="/share_zhuyixuan05/zhuyixuan05/rlbench")
    arg_parser.add_argument("--text_encoder_path", type=str,
                            default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
    arg_parser.add_argument("--vae_path", type=str,
                            default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
    arg_parser.add_argument("--output_dir", type=str,
                            default="/share_zhuyixuan05/zhuyixuan05/rlbench")
    cli_args = arg_parser.parse_args()
    encode_scenes(cli_args.scenes_path, cli_args.text_encoder_path,
                  cli_args.vae_path, cli_args.output_dir)
scripts/add_text_emb_spatialvid.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ import pdb
13
+ from tqdm import tqdm
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
class VideoEncoder(pl.LightningModule):
    """Thin wrapper bundling the WanVideo ReCamMaster pipeline with frame preprocessing.

    Loads the T5 text encoder and the Wan VAE on CPU in bfloat16 and exposes
    helpers to decode a video file into a normalized (C, T, H, W) tensor.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(manager)
        # Tiling options forwarded to the VAE when encoding videos.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # PIL frame -> float tensor in [0, 1] -> normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize *image* to the fixed 832x480 model resolution (bilinear)."""
        target_width, target_height = 832, 480
        return v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Decode *video_path* into a (C, T, H, W) tensor, or None for an empty video."""
        reader = imageio.get_reader(video_path)
        processed = [
            self.frame_process(self.crop_and_resize(Image.fromarray(raw)))
            for raw in reader
        ]
        reader.close()

        if not processed:
            return None

        stacked = torch.stack(processed, dim=0)  # (T, C, H, W)
        return rearrange(stacked, "T C H W -> C T H W")
62
+
63
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Scrub gradient tracking from cached prompt embeddings of every scene.

    Loads each ``<output_dir>/<stem>/encoded_video.pth``, reports missing
    required keys, and when ``prompt_emb['context']`` still carries
    ``requires_grad=True`` detaches it and rewrites the file in place.

    Args:
        scenes_path: Directory whose entries name the scenes to process.
        text_encoder_path: Path to the T5 text-encoder checkpoint.
        vae_path: Path to the Wan VAE checkpoint.
        output_dir: Root directory holding the per-scene encoded ``.pth`` files.
    """
    # NOTE(review): the encoder itself is no longer used below (the encode path
    # was disabled upstream); kept to preserve the original loading side effects.
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    os.makedirs(output_dir, exist_ok=True)
    required_keys = ["latents", "cam_emb", "prompt_emb"]

    # List the directory once instead of twice (originally listed for tqdm's total too).
    scene_names = os.listdir(scenes_path)
    for scene_name in tqdm(scene_names, total=len(scene_names)):
        # Scene files like "foo.mp4" map to the encoded directory "foo".
        save_dir = os.path.join(output_dir, scene_name.split('.')[0])
        encoded_path = os.path.join(save_dir, "encoded_video.pth")

        # Fix: skip entries that were never encoded instead of crashing torch.load.
        if not os.path.exists(encoded_path):
            print(f"Missing encoded file, skipping: {encoded_path}")
            continue

        data = torch.load(encoded_path, weights_only=False, map_location="cpu")
        missing_keys = [key for key in required_keys if key not in data]
        if missing_keys:
            print(f"警告: 文件 {encoded_path} 中缺少以下必要元素: {missing_keys}")

        # Fix: the original indexed data['prompt_emb'] unconditionally and raised
        # KeyError right after warning that the key might be missing.
        if "prompt_emb" in data and data['prompt_emb']['context'].requires_grad:
            print(f"警告: 文件 {encoded_path} 中存在含梯度变量,已消除")
            data['prompt_emb']['context'] = data['prompt_emb']['context'].detach().clone()
            # Belt and braces: detach() already clears the flag.
            data['prompt_emb']['context'].requires_grad_(False)
            assert not data['prompt_emb']['context'].requires_grad, "梯度仍未消除!"
            torch.save(data, encoded_path)

        processed_count += 1

    print(processed_count)
    print(f"Encoding completed! Processed {processed_count} scenes.")
160
+
161
if __name__ == "__main__":
    # Script entry point: all paths default to the spatialvid layout.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--scenes_path", type=str,
                            default="/share_zhuyixuan05/zhuyixuan05/spatialvid")
    arg_parser.add_argument("--text_encoder_path", type=str,
                            default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
    arg_parser.add_argument("--vae_path", type=str,
                            default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
    arg_parser.add_argument("--output_dir", type=str,
                            default="/share_zhuyixuan05/zhuyixuan05/spatialvid")
    cli_args = arg_parser.parse_args()
    encode_scenes(cli_args.scenes_path, cli_args.text_encoder_path,
                  cli_args.vae_path, cli_args.output_dir)
scripts/analyze_openx.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
def analyze_openx_dataset_frame_counts(dataset_path):
    """Analyse the per-episode frame-count distribution of an encoded OpenX dataset.

    Scans *dataset_path* for episode directories containing ``encoded_video.pth``,
    reads the temporal (T) dimension of each latent tensor, prints summary
    statistics, writes a report to ``frame_count_analysis.txt`` inside the
    dataset directory, and returns the key counters as a dict (None on failure).
    """
    import datetime  # replaces the original __import__('datetime') hack

    print(f"🔧 分析OpenX数据集: {dataset_path}")

    if not os.path.exists(dataset_path):
        print(f" ⚠️ 路径不存在: {dataset_path}")
        return

    episode_dirs = []
    total_episodes = 0
    valid_episodes = 0

    # Collect every episode directory that has already been encoded.
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir):
            total_episodes += 1
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                episode_dirs.append(episode_dir)
                valid_episodes += 1

    print(f"📊 总episode数: {total_episodes}")
    print(f"📊 有效episode数: {valid_episodes}")

    if len(episode_dirs) == 0:
        print("❌ 没有找到有效的episode")
        return

    frame_counts = []
    less_than_10 = 0
    less_than_8 = 0
    less_than_5 = 0
    error_count = 0

    print("🔧 开始分析帧数分布...")

    for episode_dir in tqdm(episode_dirs, desc="分析episodes"):
        try:
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu"
            )

            latents = encoded_data['latents']  # [C, T, H, W]
            frame_count = latents.shape[1]     # T dimension
            frame_counts.append(frame_count)

            if frame_count < 10:
                less_than_10 += 1
            if frame_count < 8:
                less_than_8 += 1
            if frame_count < 5:
                less_than_5 += 1

        except Exception as e:
            error_count += 1
            if error_count <= 5:  # only report the first five failures
                print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}")

    total_valid = len(frame_counts)

    # Fix: the original divided by total_valid unconditionally and crashed with
    # ZeroDivisionError when every episode failed to load.
    if total_valid == 0:
        print("❌ 所有episode均加载失败,无法统计")
        return

    print(f"\n📈 帧数分布统计:")
    print(f" 总有效episodes: {total_valid}")
    print(f" 错误episodes: {error_count}")
    print(f" 最小帧数: {min(frame_counts)}")
    print(f" 最大帧数: {max(frame_counts)}")
    # Fix: the original wrote `print(f"...") if frame_counts else 0`, which
    # printed a bare 0 for the empty case; the guard above makes this safe.
    print(f" 平均帧数: {sum(frame_counts) / len(frame_counts):.2f}")

    print(f"\n🎯 关键统计:")
    print(f" 帧数 < 5: {less_than_5:6d} episodes ({less_than_5/total_valid*100:.2f}%)")
    print(f" 帧数 < 8: {less_than_8:6d} episodes ({less_than_8/total_valid*100:.2f}%)")
    print(f" 帧数 < 10: {less_than_10:6d} episodes ({less_than_10/total_valid*100:.2f}%)")
    print(f" 帧数 >= 10: {total_valid-less_than_10:6d} episodes ({(total_valid-less_than_10)/total_valid*100:.2f}%)")

    # Detailed distribution over fixed frame-count buckets.
    frame_counts.sort()
    print(f"\n📊 详细帧数分布:")

    ranges = [
        (1, 4, "1-4帧"),
        (5, 7, "5-7帧"),
        (8, 9, "8-9帧"),
        (10, 19, "10-19帧"),
        (20, 49, "20-49帧"),
        (50, 99, "50-99帧"),
        (100, float('inf'), "100+帧")
    ]

    for min_f, max_f, label in ranges:
        count = sum(1 for f in frame_counts if min_f <= f <= max_f)
        percentage = count / total_valid * 100
        print(f" {label:8s}: {count:6d} episodes ({percentage:5.2f}%)")

    # Suggested training configuration, assuming a 4x temporal compression.
    print(f"\n💡 训练配置建议:")
    time_compression_ratio = 4
    min_condition_compressed = 4 // time_compression_ratio    # 1 compressed conditioning frame
    target_frames_compressed = 32 // time_compression_ratio   # 8 compressed target frames
    min_required_compressed = min_condition_compressed + target_frames_compressed  # 9

    usable_episodes = sum(1 for f in frame_counts if f >= min_required_compressed)
    usable_percentage = usable_episodes / total_valid * 100

    print(f" 最小条件帧数(压缩后): {min_condition_compressed}")
    print(f" 目标帧数(压缩后): {target_frames_compressed}")
    print(f" 最小所需帧数(压缩后): {min_required_compressed}")
    print(f" 可用于训练的episodes: {usable_episodes} ({usable_percentage:.2f}%)")

    # Persist the full report next to the dataset.
    output_file = os.path.join(dataset_path, "frame_count_analysis.txt")
    with open(output_file, 'w') as f:
        f.write(f"OpenX Dataset Frame Count Analysis\n")
        f.write(f"Dataset Path: {dataset_path}\n")
        f.write(f"Analysis Date: {datetime.datetime.now()}\n\n")

        f.write(f"Total Episodes: {total_episodes}\n")
        f.write(f"Valid Episodes: {total_valid}\n")
        f.write(f"Error Episodes: {error_count}\n\n")

        f.write(f"Frame Count Statistics:\n")
        f.write(f"  Min Frames: {min(frame_counts)}\n")
        f.write(f"  Max Frames: {max(frame_counts)}\n")
        f.write(f"  Avg Frames: {sum(frame_counts) / len(frame_counts):.2f}\n\n")

        f.write(f"Key Statistics:\n")
        f.write(f"  < 5 frames: {less_than_5} ({less_than_5/total_valid*100:.2f}%)\n")
        f.write(f"  < 8 frames: {less_than_8} ({less_than_8/total_valid*100:.2f}%)\n")
        f.write(f"  < 10 frames: {less_than_10} ({less_than_10/total_valid*100:.2f}%)\n")
        f.write(f"  >= 10 frames: {total_valid-less_than_10} ({(total_valid-less_than_10)/total_valid*100:.2f}%)\n\n")

        f.write(f"Detailed Distribution:\n")
        for min_f, max_f, label in ranges:
            count = sum(1 for f in frame_counts if min_f <= f <= max_f)
            percentage = count / total_valid * 100
            f.write(f"  {label}: {count} ({percentage:.2f}%)\n")

        f.write(f"\nTraining Configuration Recommendation:\n")
        f.write(f"  Usable Episodes (>= {min_required_compressed} compressed frames): {usable_episodes} ({usable_percentage:.2f}%)\n")

        # Dump every frame count, 20 values per line.
        f.write(f"\nAll Frame Counts:\n")
        for i, count in enumerate(frame_counts):
            f.write(f"{count}")
            if (i + 1) % 20 == 0:
                f.write("\n")
            else:
                f.write(", ")

    print(f"\n💾 详细统计已保存到: {output_file}")

    return {
        'total_valid': total_valid,
        'less_than_10': less_than_10,
        'less_than_8': less_than_8,
        'less_than_5': less_than_5,
        'frame_counts': frame_counts,
        'usable_episodes': usable_episodes
    }
168
+
169
def quick_sample_analysis(dataset_path, sample_size=1000):
    """Estimate the frame-count distribution from a random sample of episodes.

    Randomly samples up to *sample_size* encoded episodes under *dataset_path*,
    reads each latent tensor's temporal dimension, prints sample statistics and
    an extrapolated estimate for the whole dataset.
    """
    import random

    print(f"🚀 快速采样分析 (样本数: {sample_size})")

    episode_dirs = []
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir):
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                episode_dirs.append(episode_dir)

    if len(episode_dirs) == 0:
        print("❌ 没有找到有效的episode")
        return

    # Random sample without replacement, capped at the population size.
    sample_dirs = random.sample(episode_dirs, min(sample_size, len(episode_dirs)))

    frame_counts = []
    less_than_10 = 0
    failed = 0

    for episode_dir in tqdm(sample_dirs, desc="采样分析"):
        try:
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu"
            )

            frame_count = encoded_data['latents'].shape[1]
            frame_counts.append(frame_count)

            if frame_count < 10:
                less_than_10 += 1

        except Exception:
            # Best-effort sampling: count failures instead of silently dropping them.
            failed += 1
            continue

    total_sample = len(frame_counts)

    # Fix: the original divided by total_sample unconditionally and crashed with
    # ZeroDivisionError when every sampled episode failed to load.
    if total_sample == 0:
        print(f"❌ 采样的episode均加载失败 (失败数: {failed})")
        return

    percentage_less_than_10 = less_than_10 / total_sample * 100

    print(f"📊 采样结果:")
    print(f" 采样数量: {total_sample}")
    print(f" < 10帧: {less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f" >= 10帧: {total_sample - less_than_10} ({100 - percentage_less_than_10:.2f}%)")
    print(f" 平均帧数: {sum(frame_counts) / len(frame_counts):.2f}")

    # Extrapolate the sampled ratio to the full dataset.
    total_episodes = len(episode_dirs)
    estimated_less_than_10 = int(total_episodes * percentage_less_than_10 / 100)

    print(f"\n🔮 全数据集估算:")
    print(f" 总episodes: {total_episodes}")
    print(f" 估算 < 10帧: {estimated_less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f" 估算 >= 10帧: {total_episodes - estimated_less_than_10} ({100 - percentage_less_than_10:.2f}%)")
227
+
228
if __name__ == "__main__":
    import argparse

    # Command-line entry point; --quick switches to the sampled estimate.
    cli = argparse.ArgumentParser(description="分析OpenX数据集的帧数分布")
    cli.add_argument("--dataset_path", type=str,
                     default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
                     help="OpenX编码数据集路径")
    cli.add_argument("--quick", action="store_true", help="快速采样分析模式")
    cli.add_argument("--sample_size", type=int, default=1000, help="快速模式的采样数量")
    opts = cli.parse_args()

    if opts.quick:
        quick_sample_analysis(opts.dataset_path, opts.sample_size)
    else:
        analyze_openx_dataset_frame_counts(opts.dataset_path)
scripts/analyze_pose.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from pose_classifier import PoseClassifier
6
+ import torch
7
+ from collections import defaultdict
8
+
9
+ def analyze_turning_patterns_detailed(dataset_path, num_samples=50):
10
+ """详细分析转弯模式,基于相对于reference的pose变化"""
11
+ classifier = PoseClassifier()
12
+ samples_path = os.path.join(dataset_path, "samples")
13
+
14
+ all_analyses = []
15
+ sample_count = 0
16
+
17
+ # 用于统计每个类别的样本
18
+ class_samples = defaultdict(list)
19
+
20
+ print("=== 开始分析样本(基于相对于reference的变化)===")
21
+
22
+ for item in sorted(os.listdir(samples_path)): # 排序以便有序输出
23
+ if sample_count >= num_samples:
24
+ break
25
+
26
+ sample_dir = os.path.join(samples_path, item)
27
+ if os.path.isdir(sample_dir):
28
+ poses_path = os.path.join(sample_dir, "poses.json")
29
+ if os.path.exists(poses_path):
30
+ try:
31
+ with open(poses_path, 'r') as f:
32
+ poses_data = json.load(f)
33
+
34
+ target_relative_poses = poses_data['target_relative_poses']
35
+
36
+ if len(target_relative_poses) > 0:
37
+ # 🔧 创建相对pose向量(已经是相对于reference的)
38
+ pose_vecs = []
39
+ for pose_data in target_relative_poses:
40
+ # 相对位移(已经是相对于reference计算的)
41
+ translation = torch.tensor(pose_data['relative_translation'], dtype=torch.float32)
42
+
43
+ # 🔧 相对旋转(需要从current和reference计算)
44
+ current_rotation = torch.tensor(pose_data['current_rotation'], dtype=torch.float32)
45
+ reference_rotation = torch.tensor(pose_data['reference_rotation'], dtype=torch.float32)
46
+
47
+ # 计算相对旋转:q_relative = q_ref^-1 * q_current
48
+ relative_rotation = calculate_relative_rotation(current_rotation, reference_rotation)
49
+
50
+ # 组合为7D向量:[relative_translation, relative_rotation]
51
+ pose_vec = torch.cat([translation, relative_rotation], dim=0)
52
+ pose_vecs.append(pose_vec)
53
+
54
+ if pose_vecs:
55
+ pose_sequence = torch.stack(pose_vecs, dim=0)
56
+
57
+ # 🔧 使用新的分析方法
58
+ analysis = classifier.analyze_pose_sequence(pose_sequence)
59
+ analysis['sample_name'] = item
60
+ all_analyses.append(analysis)
61
+
62
+ # 🔧 详细输出每个样本的分类信息
63
+ print(f"\n--- 样本 {sample_count + 1}: {item} ---")
64
+ print(f"总帧数: {analysis['total_frames']}")
65
+ print(f"总距离: {analysis['total_distance']:.4f}")
66
+
67
+ # 分类分布
68
+ class_dist = analysis['class_distribution']
69
+ print(f"分类分布:")
70
+ for class_name, count in class_dist.items():
71
+ percentage = count / analysis['total_frames'] * 100
72
+ print(f" {class_name}: {count} 帧 ({percentage:.1f}%)")
73
+
74
+ # 🔧 调试前几个pose的分类过程
75
+ print(f"前3帧的详细分类过程:")
76
+ for i in range(min(3, len(pose_vecs))):
77
+ debug_info = classifier.debug_single_pose(
78
+ pose_vecs[i][:3], pose_vecs[i][3:7]
79
+ )
80
+ print(f" 帧{i}: {debug_info['classification']} "
81
+ f"(yaw: {debug_info['yaw_angle_deg']:.2f}°, "
82
+ f"forward: {debug_info['forward_movement']:.3f})")
83
+
84
+ # 运动段落
85
+ print(f"运动段落:")
86
+ for i, segment in enumerate(analysis['motion_segments']):
87
+ print(f" 段落{i+1}: {segment['class']} (帧 {segment['start_frame']}-{segment['end_frame']}, 持续 {segment['duration']} 帧)")
88
+
89
+ # 🔧 确定主要运动类型
90
+ dominant_class = max(class_dist.items(), key=lambda x: x[1])
91
+ dominant_class_name = dominant_class[0]
92
+ dominant_percentage = dominant_class[1] / analysis['total_frames'] * 100
93
+
94
+ print(f"主要运动类型: {dominant_class_name} ({dominant_percentage:.1f}%)")
95
+
96
+ # 将样本添加到对应类别
97
+ class_samples[dominant_class_name].append({
98
+ 'name': item,
99
+ 'percentage': dominant_percentage,
100
+ 'analysis': analysis
101
+ })
102
+
103
+ sample_count += 1
104
+
105
+ except Exception as e:
106
+ print(f"❌ 处理样本 {item} 时出错: {e}")
107
+
108
+ print("\n" + "="*60)
109
+ print("=== 按类别分组的样本统计(基于相对于reference的变化)===")
110
+
111
+ # 🔧 按类别输出样本列表
112
+ for class_name in ['forward', 'backward', 'left_turn', 'right_turn']:
113
+ samples = class_samples[class_name]
114
+ print(f"\n🔸 {class_name.upper()} 类样本 (共 {len(samples)} 个):")
115
+
116
+ if samples:
117
+ # 按主要类别占比排序
118
+ samples.sort(key=lambda x: x['percentage'], reverse=True)
119
+
120
+ for i, sample_info in enumerate(samples, 1):
121
+ print(f" {i:2d}. {sample_info['name']} ({sample_info['percentage']:.1f}%)")
122
+
123
+ # 显示详细的段落信息
124
+ segments = sample_info['analysis']['motion_segments']
125
+ segment_summary = []
126
+ for seg in segments:
127
+ if seg['duration'] >= 2: # 只显示持续时间>=2帧的段落
128
+ segment_summary.append(f"{seg['class']}({seg['duration']})")
129
+
130
+ if segment_summary:
131
+ print(f" 段落: {' -> '.join(segment_summary)}")
132
+ else:
133
+ print(" (无样本)")
134
+
135
+ # 🔧 统计总体模式
136
+ print(f"\n" + "="*60)
137
+ print("=== 总体统计 ===")
138
+
139
+ total_forward = sum(a['class_distribution']['forward'] for a in all_analyses)
140
+ total_backward = sum(a['class_distribution']['backward'] for a in all_analyses)
141
+ total_left_turn = sum(a['class_distribution']['left_turn'] for a in all_analyses)
142
+ total_right_turn = sum(a['class_distribution']['right_turn'] for a in all_analyses)
143
+ total_frames = total_forward + total_backward + total_left_turn + total_right_turn
144
+
145
+ print(f"总样本数: {len(all_analyses)}")
146
+ print(f"总帧数: {total_frames}")
147
+ print(f"Forward: {total_forward} 帧 ({total_forward/total_frames*100:.1f}%)")
148
+ print(f"Backward: {total_backward} 帧 ({total_backward/total_frames*100:.1f}%)")
149
+ print(f"Left Turn: {total_left_turn} 帧 ({total_left_turn/total_frames*100:.1f}%)")
150
+ print(f"Right Turn: {total_right_turn} 帧 ({total_right_turn/total_frames*100:.1f}%)")
151
+
152
+ # 🔧 样本分布统计
153
+ print(f"\n按主要类型的样本分布:")
154
+ for class_name in ['forward', 'backward', 'left_turn', 'right_turn']:
155
+ count = len(class_samples[class_name])
156
+ percentage = count / len(all_analyses) * 100 if all_analyses else 0
157
+ print(f" {class_name}: {count} 样本 ({percentage:.1f}%)")
158
+
159
+ return all_analyses, class_samples
160
+
161
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Compute the relative rotation quaternion q_rel = q_ref^-1 ⊗ q_current.

    Args:
        current_rotation: Quaternion [w, x, y, z] as a sequence or tensor.
        reference_rotation: Quaternion [w, x, y, z] as a sequence or tensor.

    Returns:
        torch.Tensor of shape (4,), the Hamilton product of the reference
        conjugate with the current quaternion, in [w, x, y, z] order.

    Note:
        Assumes unit quaternions, so the inverse equals the conjugate.
    """
    # as_tensor is a no-op for tensors; the original torch.tensor(...) call
    # copied the input and emitted a UserWarning when the caller (as in
    # analyze_turning_patterns_detailed) already passed tensors.
    q_current = torch.as_tensor(current_rotation, dtype=torch.float32)
    q_ref = torch.as_tensor(reference_rotation, dtype=torch.float32)

    # Inverse of a unit quaternion is its conjugate: [w, -x, -y, -z].
    # torch.stack avoids building a tensor from a list of 0-dim tensors.
    q_ref_inv = torch.stack([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]])

    # Hamilton product q_ref_inv ⊗ q_current.
    w1, x1, y1, z1 = q_ref_inv
    w2, x2, y2, z2 = q_current

    relative_rotation = torch.stack([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])

    return relative_rotation
181
+
182
if __name__ == "__main__":
    # Hard-coded dataset root for this analysis run.
    dataset_path = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_2"

    print("开始详细分析pose分类(基于相对于reference的变化)...")
    # num_samples caps how many sample directories are processed.
    all_analyses, class_samples = analyze_turning_patterns_detailed(dataset_path, num_samples=4000)

    print(f"\n🎉 分析完成! 共处理 {len(all_analyses)} 个样本")
scripts/batch_drone.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import random
import subprocess
import time

# Roots for encoded inputs and generated outputs; adjust to your environment.
src_root = "/share_zhuyixuan05/zhuyixuan05/spatialvid"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_spatialvid_drone_first"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_origin.py"  # change to your actual path


def main():
    """Repeatedly pick a random encoded sample and run inference on it.

    Loops until src_root has no sub-directories (interrupt with Ctrl-C
    otherwise). Fix vs. original: the loop used to run at import time;
    it is now behind a main() guard so importing this module is side-effect
    free, while `python batch_drone.py` behaves exactly as before.
    """
    while True:
        # Pick a random sample sub-directory.
        subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
        if not subdirs:
            print("没有可用的子文件夹")
            break
        chosen = random.choice(subdirs)
        chosen_dir = os.path.join(src_root, chosen)
        pth_file = os.path.join(chosen_dir, "encoded_video.pth")
        if not os.path.exists(pth_file):
            print(f"{pth_file} 不存在,跳过")
            continue

        # Output file named after the chosen sample.
        out_file = os.path.join(dst_root, f"{chosen}.mp4")
        print(f"开始生成: {pth_file} -> {out_file}")

        # Build the inference command (list form: no shell involved).
        cmd = [
            "python", infer_script,
            "--condition_pth", pth_file,
            "--output_path", out_file,
            "--prompt", "exploring the world",
            "--modality_type", "sekai",
            "--direction", "right",
            "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step25000_first.ckpt",
            "--use_gt_prompt",
        ]

        # Restrict the child process to a single GPU.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "0"

        # Run inference; errors in the child do not stop the loop.
        subprocess.run(cmd, env=env)


if __name__ == "__main__":
    main()
scripts/batch_infer.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import argparse
4
+ from pathlib import Path
5
+ import glob
6
+
7
def find_video_files(videos_dir):
    """Return all video files directly under *videos_dir*, sorted by path.

    Only extensions in ``video_extensions`` are matched (currently just
    ``.mp4``); the search is not recursive.
    """
    video_extensions = ['.mp4']
    matches = []

    for ext in video_extensions:
        matches.extend(glob.glob(os.path.join(videos_dir, f"*{ext}")))

    return sorted(matches)
17
+
18
def run_inference(condition_video, direction, dit_path, output_dir):
    """Run one infer_nus.py job for *condition_video* in the given *direction*.

    The output file is named ``<input stem>_<direction><input ext>`` inside
    *output_dir*. Returns True when the subprocess exits cleanly, False when
    it fails (the error and its stderr are printed).
    """
    # Derive the output file name from the input name plus the direction tag.
    input_filename = os.path.basename(condition_video)
    stem, ext = os.path.splitext(input_filename)
    output_filename = f"{stem}_{direction}{ext}"
    output_path = os.path.join(output_dir, output_filename)

    # Inference command (list form, no shell).
    cmd = [
        "python", "infer_nus.py",
        "--condition_video", condition_video,
        "--direction", direction,
        "--dit_path", dit_path,
        "--output_path", output_path,
    ]

    print(f"🎬 生成 {direction} 方向视频: {input_filename} -> {output_filename}")
    print(f"   命令: {' '.join(cmd)}")

    try:
        subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ 生成失败: {e}")
        print(f"   错误输出: {e.stderr}")
        return False

    print(f"✅ 成功生成: {output_path}")
    return True
47
+
48
def batch_inference(args):
    """Batch-inference driver: generate every requested direction for every video.

    For each video found in ``args.videos_dir`` and each direction in
    ``args.directions``, calls ``run_inference`` (skipping outputs that already
    exist unless ``args.overwrite``), tracking success/failure counts and
    printing a final summary. Returns None; all results are side effects
    (files in ``args.output_dir`` and console output).
    """
    videos_dir = args.videos_dir
    output_dir = args.output_dir
    directions = args.directions
    dit_path = args.dit_path

    # Validate the input directory before doing any work.
    if not os.path.exists(videos_dir):
        print(f"❌ 视频目录不存在: {videos_dir}")
        return

    # Ensure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)
    print(f"📁 输出目录: {output_dir}")

    # Collect all candidate videos (non-recursive, .mp4 only).
    video_files = find_video_files(videos_dir)

    if not video_files:
        print(f"❌ 在 {videos_dir} 中没有找到视频文件")
        return

    print(f"🎥 找到 {len(video_files)} 个视频文件:")
    for video in video_files:
        print(f"  - {os.path.basename(video)}")

    print(f"🎯 将为每个视频生成以下方向: {', '.join(directions)}")
    print(f"📊 总共将生成 {len(video_files) * len(directions)} 个视频")

    # Progress counters. Note: pre-existing (skipped) outputs count as completed.
    total_tasks = len(video_files) * len(directions)
    completed_tasks = 0
    failed_tasks = 0

    # Main loop: every (video, direction) pair is one task.
    for i, video_file in enumerate(video_files, 1):
        print(f"\n{'='*60}")
        print(f"处理视频 {i}/{len(video_files)}: {os.path.basename(video_file)}")
        print(f"{'='*60}")

        for j, direction in enumerate(directions, 1):
            print(f"\n--- 方向 {j}/{len(directions)}: {direction} ---")

            # Recompute the expected output path (same scheme as run_inference)
            # so existing results can be skipped without spawning a subprocess.
            input_filename = os.path.basename(video_file)
            name_parts = os.path.splitext(input_filename)
            output_filename = f"{name_parts[0]}_{direction}{name_parts[1]}"
            output_path = os.path.join(output_dir, output_filename)

            if os.path.exists(output_path) and not args.overwrite:
                print(f"⏭️  文件已存在,跳过: {output_filename}")
                completed_tasks += 1
                continue

            # Launch the actual inference subprocess.
            success = run_inference(
                condition_video=video_file,
                direction=direction,
                dit_path=dit_path,
                output_dir=output_dir,
            )

            if success:
                completed_tasks += 1
            else:
                failed_tasks += 1

            # Progress report after every attempted task.
            current_progress = completed_tasks + failed_tasks
            print(f"📈 进度: {current_progress}/{total_tasks} "
                  f"(成功: {completed_tasks}, 失败: {failed_tasks})")

    # Final summary.
    print(f"\n{'='*60}")
    print(f"🎉 批量推理完成!")
    print(f"📊 总任务数: {total_tasks}")
    print(f"✅ 成功: {completed_tasks}")
    print(f"❌ 失败: {failed_tasks}")
    print(f"📁 输出目录: {output_dir}")

    if failed_tasks > 0:
        print(f"⚠️  有 {failed_tasks} 个任务失败,请检查日志")

    # List everything now present in the output directory (includes files
    # from earlier runs, not only this run's products).
    if completed_tasks > 0:
        print(f"\n📋 生成的文件:")
        generated_files = glob.glob(os.path.join(output_dir, "*.mp4"))
        for file_path in sorted(generated_files):
            print(f"  - {os.path.basename(file_path)}")
138
+
139
def main():
    """CLI entry point: parse arguments, then preview (--dry_run) or run the batch.

    In dry-run mode only the planned (input, direction) -> output mapping is
    printed; otherwise ``batch_inference`` does the work.
    """
    parser = argparse.ArgumentParser(description="批量对nus/videos目录下的所有视频生成不同方向的输出")

    parser.add_argument("--videos_dir", type=str, default="/home/zhuyixuan05/ReCamMaster/nus/videos/4032",
                        help="输入视频目录路径")

    parser.add_argument("--output_dir", type=str, default="nus/infer_results/batch_dynamic_4032_noise",
                        help="输出视频目录路径")

    parser.add_argument("--directions", nargs="+",
                        default=["left_turn", "right_turn"],
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="要生成的方向列表")

    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/nus_dynamic/step15000_dynamic.ckpt",
                        help="训练好的DiT模型路径")

    parser.add_argument("--overwrite", action="store_true",
                        help="是否覆盖已存在的输出文件")

    parser.add_argument("--dry_run", action="store_true",
                        help="只显示将要执行的任务,不实际运行")

    args = parser.parse_args()

    if args.dry_run:
        # Preview mode: print the task plan without spawning any subprocess.
        print("🔍 预览模式 - 只显示任务,不执行")
        videos_dir = args.videos_dir
        video_files = find_video_files(videos_dir)

        print(f"📁 输入目录: {videos_dir}")
        print(f"📁 输出目录: {args.output_dir}")
        print(f"🎥 找到视频: {len(video_files)} 个")
        print(f"🎯 生成方向: {', '.join(args.directions)}")
        print(f"📊 总任务数: {len(video_files) * len(args.directions)}")

        print(f"\n将要执行的任务:")
        for video in video_files:
            for direction in args.directions:
                # Mirror the output naming scheme used by run_inference.
                input_name = os.path.basename(video)
                name_parts = os.path.splitext(input_name)
                output_name = f"{name_parts[0]}_{direction}{name_parts[1]}"
                print(f"  {input_name} -> {output_name} ({direction})")
    else:
        batch_inference(args)

if __name__ == "__main__":
    main()
scripts/batch_nus.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import random
import subprocess
import time

# Roots for encoded inputs and generated outputs; adjust to your environment.
src_root = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_nus_right_2"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # change to your actual path


def main():
    """Repeatedly pick a random nuScenes scene and run MoE inference on it.

    Loops until src_root has no sub-directories. Fix vs. original: the loop
    used to run at import time; it is now behind a main() guard so importing
    this module is side-effect free, while `python batch_nus.py` behaves
    exactly as before.
    """
    while True:
        # Pick a random scene sub-directory.
        subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
        if not subdirs:
            print("没有可用的子文件夹")
            break
        chosen = random.choice(subdirs)
        chosen_dir = os.path.join(src_root, chosen)
        pth_file = os.path.join(chosen_dir, "encoded_video-480p.pth")
        if not os.path.exists(pth_file):
            print(f"{pth_file} 不存在,跳过")
            continue

        # Output file named after the chosen scene.
        out_file = os.path.join(dst_root, f"{chosen}.mp4")
        print(f"开始生成: {pth_file} -> {out_file}")

        # Build the inference command (list form: no shell involved).
        cmd = [
            "python", infer_script,
            "--condition_pth", pth_file,
            "--output_path", out_file,
            "--prompt", "a car is driving",
            "--modality_type", "nuscenes",
            "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt",
        ]

        # Restrict the child process to a single GPU.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "1"

        # Run inference; errors in the child do not stop the loop.
        subprocess.run(cmd, env=env)


if __name__ == "__main__":
    main()
scripts/batch_rt.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import random
import subprocess
import time

# Roots for encoded inputs and generated outputs; adjust to your environment.
src_root = "/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_RT"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # change to your actual path


def main():
    """Repeatedly pick a random OpenX sample and run MoE inference on it.

    Loops until src_root has no sub-directories. Fix vs. original: the loop
    used to run at import time; it is now behind a main() guard so importing
    this module is side-effect free, while `python batch_rt.py` behaves
    exactly as before.
    """
    while True:
        # Pick a random sample sub-directory.
        subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
        if not subdirs:
            print("没有可用的子文件夹")
            break
        chosen = random.choice(subdirs)
        chosen_dir = os.path.join(src_root, chosen)
        pth_file = os.path.join(chosen_dir, "encoded_video.pth")
        if not os.path.exists(pth_file):
            print(f"{pth_file} 不存在,跳过")
            continue

        # Output file named after the chosen sample.
        out_file = os.path.join(dst_root, f"{chosen}.mp4")
        print(f"开始生成: {pth_file} -> {out_file}")

        # Build the inference command (list form: no shell involved).
        cmd = [
            "python", infer_script,
            "--condition_pth", pth_file,
            "--output_path", out_file,
            "--prompt", "A robotic arm is moving the object",
            "--modality_type", "openx",
        ]

        # Restrict the child process to a single GPU.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "1"

        # Run inference; errors in the child do not stop the loop.
        subprocess.run(cmd, env=env)


if __name__ == "__main__":
    main()
scripts/batch_spa.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import random
import subprocess
import time

# Roots for encoded inputs and generated outputs; adjust to your environment.
src_root = "/share_zhuyixuan05/zhuyixuan05/spatialvid"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_spatialvid_right"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # change to your actual path


def main():
    """Repeatedly pick a random SpatialVid sample and run MoE inference on it.

    Loops until src_root has no sub-directories. Fix vs. original: the loop
    used to run at import time; it is now behind a main() guard so importing
    this module is side-effect free, while `python batch_spa.py` behaves
    exactly as before.
    """
    while True:
        # Pick a random sample sub-directory.
        subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
        if not subdirs:
            print("没有可用的子文件夹")
            break
        chosen = random.choice(subdirs)
        chosen_dir = os.path.join(src_root, chosen)
        pth_file = os.path.join(chosen_dir, "encoded_video.pth")
        if not os.path.exists(pth_file):
            print(f"{pth_file} 不存在,跳过")
            continue

        # Output file named after the chosen sample.
        out_file = os.path.join(dst_root, f"{chosen}.mp4")
        print(f"开始生成: {pth_file} -> {out_file}")

        # Build the inference command (list form: no shell involved).
        cmd = [
            "python", infer_script,
            "--condition_pth", pth_file,
            "--output_path", out_file,
            "--prompt", "exploring the world",
            "--modality_type", "sekai",
            # "--direction", "left",
            "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt",
        ]

        # Restrict the child process to a single GPU.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "0"

        # Run inference; errors in the child do not stop the loop.
        subprocess.run(cmd, env=env)


if __name__ == "__main__":
    main()
scripts/batch_walk.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import random
import subprocess
import time

# Roots for encoded inputs and generated outputs; adjust to your environment.
src_root = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_walk"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # change to your actual path


def main():
    """Repeatedly pick a random scene and run MoE inference on it.

    Loops until src_root has no sub-directories. Fix vs. original: the loop
    used to run at import time; it is now behind a main() guard so importing
    this module is side-effect free, while `python batch_walk.py` behaves
    exactly as before.
    """
    while True:
        # Pick a random scene sub-directory.
        subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
        if not subdirs:
            print("没有可用的子文件夹")
            break
        chosen = random.choice(subdirs)
        chosen_dir = os.path.join(src_root, chosen)
        pth_file = os.path.join(chosen_dir, "encoded_video-480p.pth")
        if not os.path.exists(pth_file):
            print(f"{pth_file} 不存在,跳过")
            continue

        # Output file named after the chosen scene.
        out_file = os.path.join(dst_root, f"{chosen}.mp4")
        print(f"开始生成: {pth_file} -> {out_file}")

        # Build the inference command (list form: no shell involved).
        cmd = [
            "python", infer_script,
            "--condition_pth", pth_file,
            "--output_path", out_file,
            "--prompt", "a car is driving",
            "--modality_type", "nuscenes",
            "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt",
        ]

        # Restrict the child process to a single GPU.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "1"

        # Run inference; errors in the child do not stop the loop.
        subprocess.run(cmd, env=env)


if __name__ == "__main__":
    main()
scripts/check.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import argparse
4
+ from collections import defaultdict
5
+ import time
6
+
7
def load_checkpoint(ckpt_path):
    """Load a checkpoint from *ckpt_path* onto the CPU.

    Returns the deserialized object, or None when the file is missing or
    fails to load (the error is printed rather than raised, so callers can
    treat a None result as "skip this checkpoint").
    """
    if not os.path.exists(ckpt_path):
        return None

    try:
        return torch.load(ckpt_path, map_location='cpu')
    except Exception as e:
        # Best-effort loader: report and signal failure via None.
        print(f"❌ 加载检查点失败: {e}")
        return None
18
+
19
def compare_parameters(state_dict1, state_dict2, threshold=1e-8):
    """Element-wise compare two state dicts and split parameters into changed/unchanged.

    Args:
        state_dict1: Baseline state dict (name -> tensor), or None.
        state_dict2: Newer state dict (name -> tensor), or None.
        threshold: A parameter counts as "updated" when its max absolute
            difference exceeds this value.

    Returns:
        (updated_params, unchanged_params): dicts keyed by parameter name,
        each value holding 'max_diff', 'mean_diff' and 'shape'. Returns
        (None, None) when either input is None.

    Notes:
        - Parameters present in only one dict, or with mismatched shapes,
          are skipped silently.
    """
    if state_dict1 is None or state_dict2 is None:
        # Bug fix: this used to `return None`, but both callers in this file
        # unpack the result into two names *before* checking for None, which
        # raised "TypeError: cannot unpack non-iterable NoneType object".
        return None, None

    updated_params = {}
    unchanged_params = {}

    for name, param1 in state_dict1.items():
        param2 = state_dict2.get(name)
        if param2 is None or param1.shape != param2.shape:
            # Cannot diff a missing or reshaped parameter; skip it.
            continue

        # Cast to float so integer/bool buffers (e.g. step counters) compare
        # cleanly — torch.mean raises on integral dtypes otherwise.
        diff = torch.abs(param1.float() - param2.float())
        max_diff = torch.max(diff).item()
        mean_diff = torch.mean(diff).item()

        info = {
            'max_diff': max_diff,
            'mean_diff': mean_diff,
            'shape': param1.shape,
        }
        if max_diff > threshold:
            updated_params[name] = info
        else:
            unchanged_params[name] = info

    return updated_params, unchanged_params
50
+
51
def categorize_parameters(param_dict):
    """Bucket parameter-info entries by component, using substring keywords on names.

    Returns a dict with fixed keys 'moe_related', 'camera_related',
    'framepack_related', 'attention', 'other'. Groups are tested in that
    order and the first match wins; unmatched names land in 'other'.
    """
    keyword_groups = [
        ('moe_related', ('moe', 'gate', 'expert', 'processor')),
        ('camera_related', ('cam_encoder', 'projector', 'camera')),
        ('framepack_related', ('clean_x_embedder', 'framepack')),
        ('attention', ('attn', 'attention')),
    ]

    categories = {group: {} for group, _ in keyword_groups}
    categories['other'] = {}

    for name, info in param_dict.items():
        lowered = name.lower()
        for group, keywords in keyword_groups:
            if any(kw in lowered for kw in keywords):
                categories[group][name] = info
                break
        else:
            categories['other'][name] = info

    return categories
74
+
75
def print_category_summary(category_name, params, color_code=''):
    """Print a short report for one parameter category.

    Shows the parameter count, the min/max range of the per-parameter
    max/mean diffs, and the largest-changing parameters (up to 100,
    sorted by max_diff descending). Empty categories get a one-line notice.
    """
    if not params:
        print(f"{color_code} {category_name}: 无参数")
        return

    max_diffs = [info['max_diff'] for info in params.values()]
    mean_diffs = [info['mean_diff'] for info in params.values()]

    print(f"{color_code} {category_name} ({len(params)} 个参数):")
    print(f"   最大差异范围: {min(max_diffs):.2e} ~ {max(max_diffs):.2e}")
    print(f"   平均差异范围: {min(mean_diffs):.2e} ~ {max(mean_diffs):.2e}")

    # Largest-changing parameters first; cap the listing at 100 entries.
    by_change = sorted(params.items(), key=lambda kv: kv[1]['max_diff'], reverse=True)
    print(f"   变化最大的参数:")
    for rank, (name, info) in enumerate(by_change[:100], start=1):
        shape_str = 'x'.join(map(str, info['shape']))
        print(f"     {rank}. {name} [{shape_str}]: max_diff={info['max_diff']:.2e}")
95
+
96
def monitor_training(checkpoint_dir, check_interval=60):
    """Poll *checkpoint_dir* and report which parameters changed between checkpoints.

    Every *check_interval* seconds, looks for the newest ``step<N>.ckpt`` file;
    when a new step appears, diffs it against the previously seen checkpoint
    and prints categorized update statistics. Runs until Ctrl-C.

    NOTE(review): step parsing assumes file names are exactly 'step<NUM>.ckpt';
    names like 'step500_origin.ckpt' would make int() raise (caught by the
    broad except below, which just sleeps and retries).
    NOTE(review): if compare_parameters ever returns a bare None instead of a
    pair, the tuple-unpack below raises before the None check — confirm its
    contract.
    """
    print(f"🔍 开始监控训练进度...")
    print(f"📁 检查点目录: {checkpoint_dir}")
    print(f"⏰ 检查间隔: {check_interval}秒")
    print("=" * 80)

    # State carried across polling iterations.
    previous_ckpt = None
    previous_step = -1

    while True:
        try:
            # The directory may not exist yet (training not started).
            if not os.path.exists(checkpoint_dir):
                print(f"❌ 检查点目录不存在: {checkpoint_dir}")
                time.sleep(check_interval)
                continue

            ckpt_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('step') and f.endswith('.ckpt')]
            if not ckpt_files:
                print("⏳ 未找到检查点文件,等待中...")
                time.sleep(check_interval)
                continue

            # Sort by step number and take the newest checkpoint.
            ckpt_files.sort(key=lambda x: int(x.replace('step', '').replace('.ckpt', '')))
            latest_ckpt_file = ckpt_files[-1]
            latest_ckpt_path = os.path.join(checkpoint_dir, latest_ckpt_file)

            # Extract the step number from the file name.
            current_step = int(latest_ckpt_file.replace('step', '').replace('.ckpt', ''))

            # Nothing new since the last poll.
            if current_step <= previous_step:
                print(f"⏳ 等待新的检查点... (当前: step{current_step})")
                time.sleep(check_interval)
                continue

            print(f"\n🔍 发现新检查点: {latest_ckpt_file}")

            # Load the new checkpoint; on failure just retry next interval.
            current_state_dict = load_checkpoint(latest_ckpt_path)
            if current_state_dict is None:
                print("❌ 无法加载当前检查点")
                time.sleep(check_interval)
                continue

            # First iteration has nothing to diff against.
            if previous_ckpt is not None:
                print(f"📊 比较 step{previous_step} -> step{current_step}")

                updated_params, unchanged_params = compare_parameters(
                    previous_ckpt, current_state_dict, threshold=1e-8
                )

                if updated_params is None:
                    print("❌ 参数比较失败")
                else:
                    # Group both result sets by component category.
                    updated_categories = categorize_parameters(updated_params)
                    unchanged_categories = categorize_parameters(unchanged_params)

                    print(f"\n✅ 已更新的参数 (总共 {len(updated_params)} 个):")
                    print_category_summary("MoE相关", updated_categories['moe_related'], '🔥')
                    print_category_summary("Camera相关", updated_categories['camera_related'], '📷')
                    print_category_summary("FramePack相关", updated_categories['framepack_related'], '🎞️')
                    print_category_summary("注意力相关", updated_categories['attention'], '👁️')
                    print_category_summary("其他", updated_categories['other'], '📦')

                    print(f"\n⚠️ 未更新的参数 (总共 {len(unchanged_params)} 个):")
                    print_category_summary("MoE相关", unchanged_categories['moe_related'], '❄️')
                    print_category_summary("Camera相关", unchanged_categories['camera_related'], '❄️')
                    print_category_summary("FramePack相关", unchanged_categories['framepack_related'], '❄️')
                    print_category_summary("注意力相关", unchanged_categories['attention'], '❄️')
                    print_category_summary("其他", unchanged_categories['other'], '❄️')

                    # Sanity check: are the components we train actually moving?
                    critical_keywords = ['moe', 'cam_encoder', 'projector', 'clean_x_embedder']
                    critical_updated = any(
                        any(keyword in name.lower() for keyword in critical_keywords)
                        for name in updated_params.keys()
                    )

                    if critical_updated:
                        print("\n✅ 关键组件正在更新!")
                    else:
                        print("\n❌ 警告:关键组件可能未在更新!")

                    # Overall fraction of parameters that changed.
                    total_params = len(updated_params) + len(unchanged_params)
                    update_rate = len(updated_params) / total_params * 100
                    print(f"\n📈 参数更新率: {update_rate:.1f}% ({len(updated_params)}/{total_params})")

            # Keep the current checkpoint in memory as next iteration's baseline.
            previous_ckpt = current_state_dict
            previous_step = current_step

            print("=" * 80)
            time.sleep(check_interval)

        except KeyboardInterrupt:
            print("\n👋 监控已停止")
            break
        except Exception as e:
            # Broad catch: keep the monitor alive across transient errors
            # (partial writes, parse failures, etc.).
            print(f"❌ 监控过程中出错: {e}")
            time.sleep(check_interval)
201
+
202
def compare_two_checkpoints(ckpt1_path, ckpt2_path):
    """One-shot comparison of two checkpoint files with a categorized report.

    Loads both checkpoints, diffs them with ``compare_parameters`` (default
    threshold), and prints updated/unchanged parameter summaries per category
    plus the overall update rate. Returns None; output is console-only.

    NOTE(review): if compare_parameters ever returns a bare None instead of
    a pair, the tuple-unpack below raises before the None check — confirm
    its contract.
    """
    print(f"🔍 比较两个检查点:")
    print(f"   检查点1: {ckpt1_path}")
    print(f"   检查点2: {ckpt2_path}")
    print("=" * 80)

    # Load both checkpoints (None on any failure).
    state_dict1 = load_checkpoint(ckpt1_path)
    state_dict2 = load_checkpoint(ckpt2_path)

    if state_dict1 is None or state_dict2 is None:
        print("❌ 无法加载检查点文件")
        return

    updated_params, unchanged_params = compare_parameters(state_dict1, state_dict2)

    if updated_params is None:
        print("❌ 参数比较失败")
        return

    # Group both result sets by component category.
    updated_categories = categorize_parameters(updated_params)
    unchanged_categories = categorize_parameters(unchanged_params)

    print(f"\n✅ 已更新的参数 (总共 {len(updated_params)} 个):")
    for category_name, params in updated_categories.items():
        print_category_summary(category_name.replace('_', ' ').title(), params, '🔥')

    print(f"\n⚠️ 未更新的参数 (总共 {len(unchanged_params)} 个):")
    for category_name, params in unchanged_categories.items():
        print_category_summary(category_name.replace('_', ' ').title(), params, '❄️')

    # Overall fraction of parameters that changed.
    total_params = len(updated_params) + len(unchanged_params)
    update_rate = len(updated_params) / total_params * 100
    print(f"\n📈 参数更新率: {update_rate:.1f}% ({len(updated_params)}/{total_params})")
240
+
241
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="检查模型参数更新情况")
    parser.add_argument("--checkpoint_dir", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe",
                        help="检查点目录路径")
    # NOTE(review): plain default=True (not action="store_true") — any value
    # passed on the CLI arrives as a non-empty string, which is also truthy,
    # so the monitor branch below is only reachable via `--compare ""`.
    parser.add_argument("--compare", default=True,
                        help="比较两个特定检查点,而不是监控")
    parser.add_argument("--ckpt1", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step1500_origin_cam_4.ckpt")
    parser.add_argument("--ckpt2", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step500_origin_cam_4.ckpt")
    parser.add_argument("--interval", type=int, default=60,
                        help="监控检查间隔(秒)")
    # NOTE(review): --threshold is parsed but never forwarded to
    # compare_parameters / monitor_training; both use their own defaults.
    parser.add_argument("--threshold", type=float, default=1e-8,
                        help="参数变化阈值")

    args = parser.parse_args()

    if args.compare:
        # Compare mode requires both checkpoint paths.
        if not args.ckpt1 or not args.ckpt2:
            print("❌ 比较模式需要指定 --ckpt1 和 --ckpt2")
        else:
            compare_two_checkpoints(args.ckpt1, args.ckpt2)
    else:
        monitor_training(args.checkpoint_dir, args.interval)
scripts/decode_openx.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import imageio
6
+ import argparse
7
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
8
+ from tqdm import tqdm
9
+ import json
10
+
11
+ class VideoDecoder:
12
    def __init__(self, vae_path, device="cuda"):
        """Initialize the video decoder: load the VAE and move it to *device*.

        Args:
            vae_path: Path to the VAE weights loaded via ModelManager.
            device: Target device string (default "cuda").

        Loads weights on CPU in bfloat16 first, then moves the pipeline to
        the target device.
        """
        self.device = device

        # Load weights on CPU; the pipeline is moved to the GPU afterwards.
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([vae_path])

        # Build the full pipeline; only its VAE is used for decoding.
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe = self.pipe.to(device)

        # Explicitly move the VAE (and its inner .model, when present) as
        # well — presumably pipe.to(device) does not reach every sub-module;
        # verify against the pipeline implementation.
        self.pipe.vae = self.pipe.vae.to(device)
        if hasattr(self.pipe.vae, 'model'):
            self.pipe.vae.model = self.pipe.vae.model.to(device)

        print(f"✅ VAE解码器初始化完成,设备: {device}")
+
31
    def decode_latents_to_video(self, latents, output_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode VAE latents into an RGB video file written to `output_path`.

        Handles several possible layouts of the decoded tensor, denormalizes
        from [-1, 1] to uint8, and writes frames with imageio. Returns the
        final [T, H, W, C] uint8 numpy array. Raises ValueError when the
        decoded tensor layout cannot be recognized.
        """
        print(f"🔧 开始解码latents...")
        print(f"输入latents形状: {latents.shape}")
        print(f"输入latents设备: {latents.device}")
        print(f"输入latents数据类型: {latents.dtype}")

        # Ensure the latents carry a batch dimension.
        if len(latents.shape) == 4:  # [C, T, H, W]
            latents = latents.unsqueeze(0)  # -> [1, C, T, H, W]

        # Key fix: match the latents to the VAE's own device and dtype.
        model_dtype = next(self.pipe.vae.parameters()).dtype
        model_device = next(self.pipe.vae.parameters()).device

        print(f"模型设备: {model_device}")
        print(f"模型数据类型: {model_dtype}")

        latents = latents.to(device=model_device, dtype=model_dtype)

        print(f"解码latents形状: {latents.shape}")
        print(f"解码latents设备: {latents.device}")
        print(f"解码latents数据类型: {latents.dtype}")

        # Force the pipeline device so all downstream ops run on one device.
        self.pipe.device = model_device

        # Decode with the VAE (no gradients needed).
        with torch.no_grad():
            try:
                if tiled:
                    print("🔧 尝试tiled解码...")
                    decoded_video = self.pipe.decode_video(
                        latents,
                        tiled=True,
                        tile_size=tile_size,
                        tile_stride=tile_stride
                    )
                else:
                    print("🔧 使用非tiled解码...")
                    decoded_video = self.pipe.decode_video(latents, tiled=False)

            except Exception as e:
                print(f"decode_video失败,错误: {e}")
                import traceback
                traceback.print_exc()

                # Fallback: call the VAE decoder directly.
                try:
                    print("🔧 尝试直接调用VAE解码...")
                    decoded_video = self.pipe.vae.decode(
                        latents.squeeze(0),  # drop batch dim -> [C, T, H, W]
                        device=model_device,
                        tiled=False
                    )
                    # Re-add the batch dim: [T, H, W, C] -> [1, T, H, W, C].
                    if len(decoded_video.shape) == 4:  # [T, H, W, C]
                        decoded_video = decoded_video.unsqueeze(0)  # -> [1, T, H, W, C]
                except Exception as e2:
                    print(f"直接VAE解码也失败: {e2}")
                    raise e2

        print(f"解码后视频形状: {decoded_video.shape}")

        # Key fix: normalize whatever layout came back to [T, H, W, C].
        video_np = None

        if len(decoded_video.shape) == 5:
            # Probe the plausible 5D layouts.
            if decoded_video.shape == torch.Size([1, 3, 113, 480, 832]):
                # Layout [B, C, T, H, W] -> [T, H, W, C].
                print("🔧 检测到格式: [B, C, T, H, W]")
                video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()  # [T, H, W, C]
            elif decoded_video.shape[1] == 3:
                # Second dim == 3 strongly suggests [B, C, T, H, W].
                print("🔧 检测到可能的格式: [B, C, T, H, W]")
                video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()  # [T, H, W, C]
            elif decoded_video.shape[-1] == 3:
                # Last dim == 3 suggests [B, T, H, W, C].
                print("🔧 检测到格式: [B, T, H, W, C]")
                video_np = decoded_video[0].to(torch.float32).cpu().numpy()  # [T, H, W, C]
            else:
                # Last resort: look for any axis of size 3 to treat as channels.
                shape = list(decoded_video.shape)
                if 3 in shape:
                    channel_dim = shape.index(3)
                    print(f"🔧 检测到通道维度在位置: {channel_dim}")

                    if channel_dim == 1:  # [B, C, T, H, W]
                        video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()
                    elif channel_dim == 4:  # [B, T, H, W, C]
                        video_np = decoded_video[0].to(torch.float32).cpu().numpy()
                    else:
                        print(f"⚠️ 未知的通道维度位置: {channel_dim}")
                        raise ValueError(f"Cannot handle channel dimension at position {channel_dim}")
                else:
                    print(f"⚠️ 未找到通道维度为3的位置,形状: {decoded_video.shape}")
                    raise ValueError(f"Cannot find channel dimension of size 3 in shape {decoded_video.shape}")

        elif len(decoded_video.shape) == 4:
            # 4D tensor: disambiguate channel-last vs channel-first.
            if decoded_video.shape[-1] == 3:  # [T, H, W, C]
                video_np = decoded_video.to(torch.float32).cpu().numpy()
            elif decoded_video.shape[0] == 3:  # [C, T, H, W]
                video_np = decoded_video.permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()
            else:
                print(f"⚠️ 无法处理的4D视频形状: {decoded_video.shape}")
                raise ValueError(f"Cannot handle 4D video tensor shape: {decoded_video.shape}")
        else:
            print(f"⚠️ 意外的视频维度数: {len(decoded_video.shape)}")
            raise ValueError(f"Unexpected video tensor dimensions: {decoded_video.shape}")

        if video_np is None:
            raise ValueError("Failed to convert video tensor to numpy array")

        print(f"转换后视频数组形状: {video_np.shape}")

        # Validate the final shape before writing.
        if len(video_np.shape) != 4:
            raise ValueError(f"Expected 4D array [T, H, W, C], got {video_np.shape}")

        if video_np.shape[-1] != 3:
            print(f"⚠️ 通道数异常: 期望3,实际{video_np.shape[-1]}")
            print(f"完整形状: {video_np.shape}")
            # Try the remaining channel positions before giving up.
            if video_np.shape[0] == 3:  # [C, T, H, W]
                print("🔧 尝试重新排列: [C, T, H, W] -> [T, H, W, C]")
                video_np = np.transpose(video_np, (1, 2, 3, 0))
            elif video_np.shape[1] == 3:  # [T, C, H, W]
                print("🔧 尝试重新排列: [T, C, H, W] -> [T, H, W, C]")
                video_np = np.transpose(video_np, (0, 2, 3, 1))
            else:
                raise ValueError(f"Expected 3 channels (RGB), got {video_np.shape[-1]} channels")

        # Denormalize from [-1, 1] to uint8 [0, 255].
        video_np = (video_np * 0.5 + 0.5).clip(0, 1)
        video_np = (video_np * 255).astype(np.uint8)

        print(f"最终视频数组形状: {video_np.shape}")
        print(f"视频数组值范围: {video_np.min()} - {video_np.max()}")

        # Write the video.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        try:
            with imageio.get_writer(output_path, fps=10, quality=8) as writer:
                for frame_idx, frame in enumerate(video_np):
                    # Skip frames with a malformed shape rather than aborting.
                    if len(frame.shape) != 3 or frame.shape[-1] != 3:
                        print(f"⚠️ 帧 {frame_idx} 形状异常: {frame.shape}")
                        continue

                    writer.append_data(frame)
                    if frame_idx % 10 == 0:
                        print(f"  写入帧 {frame_idx}/{len(video_np)}")
        except Exception as e:
            print(f"保存视频失败: {e}")
            # Dump the first few frames as PNGs to aid debugging, then re-raise.
            debug_dir = os.path.join(os.path.dirname(output_path), "debug_frames")
            os.makedirs(debug_dir, exist_ok=True)

            for i in range(min(5, len(video_np))):
                frame = video_np[i]
                debug_path = os.path.join(debug_dir, f"debug_frame_{i}.png")
                try:
                    if len(frame.shape) == 3 and frame.shape[-1] == 3:
                        Image.fromarray(frame).save(debug_path)
                        print(f"调试: 保存帧 {i} 到 {debug_path}")
                    else:
                        print(f"调试: 帧 {i} 形状异常: {frame.shape}")
                except Exception as e2:
                    print(f"调试: 保存帧 {i} 失败: {e2}")
            raise e

        print(f"✅ 视频保存到: {output_path}")
        return video_np
210
+
211
+ def save_frames_as_images(self, video_np, output_dir, prefix="frame"):
212
+ """将视频帧保存为单独的图像文件"""
213
+ os.makedirs(output_dir, exist_ok=True)
214
+
215
+ for i, frame in enumerate(video_np):
216
+ frame_path = os.path.join(output_dir, f"{prefix}_{i:04d}.png")
217
+ # 🔧 验证帧形状
218
+ if len(frame.shape) == 3 and frame.shape[-1] == 3:
219
+ Image.fromarray(frame).save(frame_path)
220
+ else:
221
+ print(f"⚠️ 跳过形状异常的帧 {i}: {frame.shape}")
222
+
223
+ print(f"✅ 保存了 {len(video_np)} 帧到: {output_dir}")
224
+
225
def decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device="cuda"):
    """Decode one encoded episode (.pth with VAE latents) back into a video.

    Returns True on success, False on any failure; errors are printed rather
    than raised so batch callers can keep going.
    """
    print(f"\n🔧 解码episode: {encoded_pth_path}")

    # Load the encoded payload onto CPU; the decoder moves tensors later.
    try:
        encoded_data = torch.load(encoded_pth_path, weights_only=False, map_location="cpu")
        print(f"✅ 成功加载编码数据")
    except Exception as e:
        print(f"❌ 加载编码数据失败: {e}")
        return False

    # Dump the payload structure for debugging.
    print("🔍 编码数据结构:")
    for key, value in encoded_data.items():
        if isinstance(value, torch.Tensor):
            print(f"  - {key}: {value.shape}, dtype: {value.dtype}, device: {value.device}")
        elif isinstance(value, dict):
            print(f"  - {key}: dict with keys {list(value.keys())}")
        else:
            print(f"  - {key}: {type(value)}")

    # The latents are mandatory; bail out when missing.
    latents = encoded_data.get('latents')
    if latents is None:
        print("❌ 未找到latents数据")
        return False

    # Ensure the latents sit on CPU (the expected state after map_location).
    if latents.device != torch.device('cpu'):
        latents = latents.cpu()
        print(f"🔧 将latents移动到CPU: {latents.device}")

    episode_info = encoded_data.get('episode_info', {})
    episode_idx = episode_info.get('episode_idx', 'unknown')
    # Fall back to latent_frames * 4 as an estimate of the original frame
    # count when it was not recorded (assumes 4x temporal compression —
    # NOTE(review): confirm against the encoder's VAE settings).
    total_frames = episode_info.get('total_frames', latents.shape[1] * 4)

    print(f"Episode信息:")
    print(f"  - Episode索引: {episode_idx}")
    print(f"  - Latents形状: {latents.shape}")
    print(f"  - Latents设备: {latents.device}")
    print(f"  - Latents数据类型: {latents.dtype}")
    print(f"  - 原始总帧数: {total_frames}")
    print(f"  - 压缩后帧数: {latents.shape[1]}")

    # One output folder per episode.
    episode_name = f"episode_{episode_idx:06d}" if isinstance(episode_idx, int) else f"episode_{episode_idx}"
    output_dir = os.path.join(output_base_dir, episode_name)
    os.makedirs(output_dir, exist_ok=True)

    # Build the decoder (loads the VAE checkpoint).
    try:
        decoder = VideoDecoder(vae_path, device)
    except Exception as e:
        print(f"❌ 初始化解码器失败: {e}")
        return False

    # Decode to an mp4 file.
    video_output_path = os.path.join(output_dir, "decoded_video.mp4")
    try:
        # Non-tiled decode first: simpler and avoids tiled-path edge cases.
        video_np = decoder.decode_latents_to_video(
            latents,
            video_output_path,
            tiled=False,
            tile_size=(34, 34),
            tile_stride=(18, 16)
        )

        # Save the first few frames as PNGs for a quick visual check.
        frames_dir = os.path.join(output_dir, "frames")
        sample_frames = video_np[:min(10, len(video_np))]  # first 10 frames only
        decoder.save_frames_as_images(sample_frames, frames_dir, f"frame_{episode_idx}")

        # Persist a JSON summary of this decode next to the video.
        decode_info = {
            "source_pth": encoded_pth_path,
            "decoded_video_path": video_output_path,
            "latents_shape": list(latents.shape),
            "decoded_video_shape": list(video_np.shape),
            "original_total_frames": total_frames,
            "decoded_frames": len(video_np),
            "compression_ratio": total_frames / len(video_np) if len(video_np) > 0 else 0,
            "latents_dtype": str(latents.dtype),
            "latents_device": str(latents.device),
            "vae_compression_ratio": total_frames / latents.shape[1] if latents.shape[1] > 0 else 0
        }

        info_path = os.path.join(output_dir, "decode_info.json")
        with open(info_path, 'w') as f:
            json.dump(decode_info, f, indent=2)

        print(f"✅ Episode {episode_idx} 解码完成")
        print(f"  - 原始帧数: {total_frames}")
        print(f"  - 解码帧数: {len(video_np)}")
        print(f"  - 压缩比: {decode_info['compression_ratio']:.2f}")
        print(f"  - VAE时间压缩比: {decode_info['vae_compression_ratio']:.2f}")
        return True

    except Exception as e:
        print(f"❌ 解码失败: {e}")
        import traceback
        traceback.print_exc()
        return False
328
+
329
def batch_decode_episodes(encoded_base_dir, vae_path, output_base_dir, max_episodes=None, device="cuda"):
    """Decode every encoded episode found under `encoded_base_dir`.

    Scans for `<episode_dir>/encoded_video.pth` files, decodes each via
    `decode_single_episode`, and prints running and final success stats.

    Args:
        encoded_base_dir: Directory holding one sub-directory per episode.
        vae_path: Path to the Wan2.1 VAE checkpoint.
        output_base_dir: Where decoded videos/frames are written.
        max_episodes: Optional cap on the number of episodes processed.
        device: Torch device string forwarded to the decoder.
    """
    print(f"🔧 批量解码Open-X episodes")
    print(f"源目录: {encoded_base_dir}")
    print(f"输出目录: {output_base_dir}")

    # Collect encoded_video.pth paths; sorting keeps the order deterministic.
    episode_dirs = []
    if os.path.exists(encoded_base_dir):
        for item in sorted(os.listdir(encoded_base_dir)):
            episode_dir = os.path.join(encoded_base_dir, item)
            if os.path.isdir(episode_dir):
                encoded_path = os.path.join(episode_dir, "encoded_video.pth")
                if os.path.exists(encoded_path):
                    episode_dirs.append(encoded_path)

    print(f"找到 {len(episode_dirs)} 个编码的episodes")

    # Fix: return early when nothing was found — the final success-rate
    # computation below would otherwise divide by zero.
    if not episode_dirs:
        print("⚠️ 未找到任何编码的episodes,退出")
        return

    if max_episodes and len(episode_dirs) > max_episodes:
        episode_dirs = episode_dirs[:max_episodes]
        print(f"限制处理前 {max_episodes} 个episodes")

    # Decode episodes one by one, tracking the running success rate.
    success_count = 0
    for i, encoded_pth_path in enumerate(tqdm(episode_dirs, desc="解码episodes")):
        print(f"\n{'='*60}")
        print(f"处理 {i+1}/{len(episode_dirs)}: {os.path.basename(os.path.dirname(encoded_pth_path))}")

        success = decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device)
        if success:
            success_count += 1

        print(f"当前成功率: {success_count}/{i+1} ({success_count/(i+1)*100:.1f}%)")

    print(f"\n🎉 批量解码完成!")
    print(f"总处理: {len(episode_dirs)} 个episodes")
    print(f"成功解码: {success_count} 个episodes")
    print(f"成功率: {success_count/len(episode_dirs)*100:.1f}%")
367
+
368
def main():
    """CLI entry point for the latents decode-verification tool."""
    parser = argparse.ArgumentParser(description="解码Open-X编码的latents以验证正确性 - 修正版本")
    parser.add_argument("--mode", type=str, choices=["single", "batch"], default="batch",
                        help="解码模式:single (单个episode) 或 batch (批量)")
    parser.add_argument("--encoded_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000000/encoded_video.pth",
                        help="单个编码文件路径(single模式)")
    parser.add_argument("--encoded_base_dir", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
                        help="编码数据基础目录(batch模式)")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
                        help="VAE模型路径")
    parser.add_argument("--output_dir", type=str,
                        default="./decoded_results_fixed",
                        help="解码输出目录")
    parser.add_argument("--max_episodes", type=int, default=5,
                        help="最大解码episodes数量(batch模式,用于测试)")
    parser.add_argument("--device", type=str, default="cuda",
                        help="计算设备")
    args = parser.parse_args()

    print("🔧 Open-X Latents 解码验证工具 (修正版本 - Fixed)")
    print(f"模式: {args.mode}")
    print(f"VAE路径: {args.vae_path}")
    print(f"输出目录: {args.output_dir}")
    print(f"设备: {args.device}")

    # Fall back to CPU when CUDA was requested but is not available.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("⚠️ CUDA不可用,切换到CPU")
        args.device = "cpu"

    os.makedirs(args.output_dir, exist_ok=True)

    if args.mode == "single":
        print(f"输入文件: {args.encoded_pth}")
        if not os.path.exists(args.encoded_pth):
            print(f"❌ 输入文件不存在: {args.encoded_pth}")
            return
        ok = decode_single_episode(args.encoded_pth, args.vae_path, args.output_dir, args.device)
        print("✅ 单个episode解码成功" if ok else "❌ 单个episode解码失败")
    elif args.mode == "batch":
        print(f"输入目录: {args.encoded_base_dir}")
        print(f"最大episodes: {args.max_episodes}")
        if not os.path.exists(args.encoded_base_dir):
            print(f"❌ 输入目录不存在: {args.encoded_base_dir}")
            return
        batch_decode_episodes(args.encoded_base_dir, args.vae_path, args.output_dir,
                              args.max_episodes, args.device)


if __name__ == "__main__":
    main()
scripts/download_recam.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
from huggingface_hub import snapshot_download

# Download the ReCamMaster checkpoint. Interrupted downloads resume
# automatically in huggingface_hub (>= 0.22), so the deprecated
# `resume_download=True` flag is no longer passed.
snapshot_download(
    repo_id="KwaiVGI/ReCamMaster-Wan2.1",
    local_dir="models/ReCamMaster/checkpoints",
)
scripts/download_wan2.1.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
from modelscope import snapshot_download


# Download the Wan2.1 T2V 1.3B base weights from ModelScope into the local
# models/ directory (the path other scripts in this repo default to).
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
scripts/encode_dynamic_videos.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ from tqdm import tqdm
12
class VideoEncoder(pl.LightningModule):
    """Wraps the Wan2.1 text-encoder/VAE pipeline to turn videos into latents."""

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        # Load both models on CPU in bf16; the caller moves the module to GPU.
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiling options forwarded verbatim to pipe.encode_video.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Map PIL frames to normalized tensors in [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=832, target_height=480):
        """Resize a PIL image to the pipeline's working resolution (bilinear).

        The 832x480 defaults match the original hard-coded values, so
        existing callers are unaffected.
        """
        # Fix: the original computed `image.size` into unused locals; removed.
        return v2.functional.resize(
            image,
            (round(target_height), round(target_width)),
            interpolation=v2.InterpolationMode.BILINEAR
        )

    def load_video_frames(self, video_path):
        """Load a whole video as a [C, T, H, W] tensor, or None when empty."""
        reader = imageio.get_reader(video_path)
        frames = []

        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frame = self.frame_process(frame)
                frames.append(frame)
        finally:
            # Fix: close the reader even when a frame fails to decode.
            reader.close()

        if len(frames) == 0:
            return None

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames
56
+
57
def encode_scenes(scenes_path, text_encoder_path, vae_path, start_idx=450):
    """Encode every scene video under `scenes_path` into VAE latents.

    For each scene directory: skips already-encoded scenes, reads
    `scene_info.json` to locate the video, encodes it with the VAE, attaches
    a shared text-prompt embedding, and saves the result as
    `encoded_video-480p-1.pth` inside the scene directory.

    Args:
        scenes_path: Root directory with one sub-directory per scene.
        text_encoder_path: Path to the T5 text-encoder checkpoint.
        vae_path: Path to the Wan2.1 VAE checkpoint.
        start_idx: Scenes before this listdir index are skipped (default 450,
            matching the previous hard-coded resume point).
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    # Fix: the prompt embedding was previously gated on `processed_count == 0`
    # and only stored in a local on the first success — if the first scene
    # failed after `del encoder.pipe.prompter`, every later scene crashed with
    # UnboundLocalError/AttributeError. A None sentinel (the same pattern used
    # by encode_openx.py) makes the lazy encode safe.
    prompt_emb = None

    for idx, scene_name in enumerate(tqdm(os.listdir(scenes_path))):
        # Resume point: skip scenes before `start_idx`.
        if idx < start_idx:
            continue
        scene_dir = os.path.join(scenes_path, scene_name)
        if not os.path.isdir(scene_dir):
            continue

        # Skip scenes that were already encoded.
        encoded_path = os.path.join(scene_dir, "encoded_video-480p-1.pth")
        if os.path.exists(encoded_path):
            print(f"Scene {scene_name} already encoded, skipping...")
            continue

        # The scene metadata points at the video file.
        scene_info_path = os.path.join(scene_dir, "scene_info.json")
        if not os.path.exists(scene_info_path):
            continue

        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)

        video_path = os.path.join(scene_dir, scene_info['video_path'])
        if not os.path.exists(video_path):
            print(f"Video not found: {video_path}")
            continue

        try:
            print(f"Encoding scene {scene_name}...")

            video_frames = encoder.load_video_frames(video_path)
            if video_frames is None:
                print(f"Failed to load video: {video_path}")
                continue

            video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)

            # Encode the video into latents.
            with torch.no_grad():
                latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]

            # Encode the shared text prompt once, then drop the prompter to
            # free GPU memory.
            if prompt_emb is None:
                print('encode prompt!!!')
                prompt_emb = encoder.pipe.encode_prompt("A car driving scene captured by front camera")
                del encoder.pipe.prompter

            # Persist the encoded scene.
            encoded_data = {
                "latents": latents.cpu(),
                "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
                "image_emb": {}
            }

            torch.save(encoded_data, encoded_path)
            print(f"Saved encoded data: {encoded_path}")
            processed_count += 1

        except Exception as e:
            print(f"Error encoding scene {scene_name}: {e}")
            continue

    print(f"Encoding completed! Processed {processed_count} scenes.")
131
+
132
if __name__ == "__main__":
    # CLI entry point; path defaults target the shared cluster layout.
    cli = argparse.ArgumentParser()
    for flag, default in [
        ("--scenes_path", "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes"),
        ("--text_encoder_path", "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"),
        ("--vae_path", "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"),
    ]:
        cli.add_argument(flag, type=str, default=default)
    opts = cli.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path, opts.vae_path)
scripts/encode_openx.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
# Key fix: set these environment variables BEFORE importing TFDS so it reads
# local data instead of trying to reach Google Cloud Storage.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TFDS_DISABLE_GCS"] = "1"

import tensorflow_datasets as tfds
import tensorflow as tf
20
+
21
+ class VideoEncoder(pl.LightningModule):
22
    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Load the T5 text encoder and Wan2.1 VAE and build the pipeline.

        Models are loaded on CPU in bf16; the caller moves the module to GPU.
        """
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiling options forwarded verbatim to pipe.encode_video.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Map PIL frames to normalized tensors in [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
33
+
34
+ def crop_and_resize(self, image, target_width=832, target_height=480):
35
+ """调整图像尺寸"""
36
+ image = v2.functional.resize(
37
+ image,
38
+ (target_height, target_width),
39
+ interpolation=v2.InterpolationMode.BILINEAR
40
+ )
41
+ return image
42
+
43
    def load_episode_frames(self, episode_data, max_frames=300):
        """Extract RGB frames from a fractal episode as a [C, T, H, W] tensor.

        Iterates the episode's steps, reads the first available image field
        from each observation, resizes/normalizes it, and stacks at most
        `max_frames` frames. Returns None when no frame could be extracted.
        """
        frames = []

        steps = episode_data['steps']
        frame_count = 0

        print(f"开始提取帧,最多 {max_frames} 帧...")

        for step_idx, step in enumerate(steps):
            if frame_count >= max_frames:
                break

            try:
                obs = step['observation']

                # Candidate observation keys in priority order; 'image' is the
                # field confirmed to exist in fractal20220817_data.
                img_data = None
                image_keys_to_try = [
                    'image',                  # confirmed primary image field
                    'rgb',                    # fallback RGB image
                    'camera_image',           # fallback camera image
                    'exterior_image_1_left',  # possible exterior camera
                    'wrist_image',            # possible wrist camera
                ]

                for img_key in image_keys_to_try:
                    if img_key in obs:
                        try:
                            img_tensor = obs[img_key]
                            img_data = img_tensor.numpy()
                            if step_idx < 3:  # only log for the first few steps
                                print(f"✅ 找到图像字段: {img_key}, 形状: {img_data.shape}")
                            break
                        except Exception as e:
                            if step_idx < 3:
                                print(f"尝试字段 {img_key} 失败: {e}")
                            continue

                if img_data is not None:
                    # Normalize the raw array to a uint8 RGB PIL image.
                    if len(img_data.shape) == 3:  # [H, W, C]
                        if img_data.dtype == np.uint8:
                            frame = Image.fromarray(img_data)
                        else:
                            # Float images: rescale [0, 1] data to [0, 255].
                            if img_data.max() <= 1.0:
                                img_data = (img_data * 255).astype(np.uint8)
                            else:
                                img_data = img_data.astype(np.uint8)
                            frame = Image.fromarray(img_data)

                        # Force RGB if the source had another mode.
                        if frame.mode != 'RGB':
                            frame = frame.convert('RGB')

                        frame = self.crop_and_resize(frame)
                        frame = self.frame_process(frame)
                        frames.append(frame)
                        frame_count += 1

                        if frame_count % 50 == 0:
                            print(f"已处理 {frame_count} 帧")
                    else:
                        if step_idx < 5:
                            print(f"步骤 {step_idx}: 图像形状不正确 {img_data.shape}")
                else:
                    # No usable image field: dump the available observation
                    # keys for the first few steps only.
                    if step_idx < 5:
                        available_keys = list(obs.keys())
                        print(f"步骤 {step_idx}: 未找到图像,可用键: {available_keys}")

            except Exception as e:
                print(f"处理步骤 {step_idx} 时出错: {e}")
                continue

        print(f"成功提取 {len(frames)} 帧")

        if len(frames) == 0:
            return None

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames
127
+
128
+ def extract_camera_poses(self, episode_data, num_frames):
129
+ """🔧 从fractal数据集提取相机位姿信息 - 基于实际observation和action字段优化"""
130
+ camera_poses = []
131
+
132
+ steps = episode_data['steps']
133
+ frame_count = 0
134
+
135
+ print("提取相机位姿信息...")
136
+
137
+ # 🔧 累积位姿信息
138
+ cumulative_translation = np.array([0.0, 0.0, 0.0], dtype=np.float32)
139
+ cumulative_rotation = np.array([0.0, 0.0, 0.0], dtype=np.float32) # 欧拉角
140
+
141
+ for step_idx, step in enumerate(steps):
142
+ if frame_count >= num_frames:
143
+ break
144
+
145
+ try:
146
+ obs = step['observation']
147
+ action = step.get('action', {})
148
+
149
+ # 🔧 基于实际的字段提取位姿变化
150
+ pose_data = {}
151
+ found_pose = False
152
+
153
+ # 1. 优先使用action中的world_vector(世界坐标系中的位移)
154
+ if 'world_vector' in action:
155
+ try:
156
+ world_vector = action['world_vector'].numpy()
157
+ if len(world_vector) == 3:
158
+ # 累积世界坐标位移
159
+ cumulative_translation += world_vector
160
+ pose_data['translation'] = cumulative_translation.copy()
161
+ found_pose = True
162
+
163
+ if step_idx < 3:
164
+ print(f"使用action.world_vector: {world_vector}, 累积位移: {cumulative_translation}")
165
+ except Exception as e:
166
+ if step_idx < 3:
167
+ print(f"action.world_vector提取失败: {e}")
168
+
169
+ # 2. 使用action中的rotation_delta(旋转变化)
170
+ if 'rotation_delta' in action:
171
+ try:
172
+ rotation_delta = action['rotation_delta'].numpy()
173
+ if len(rotation_delta) == 3:
174
+ # 累积旋转变化
175
+ cumulative_rotation += rotation_delta
176
+
177
+ # 转换为四元数(简化版本)
178
+ euler_angles = cumulative_rotation
179
+ # 欧拉角转四元数(ZYX顺序)
180
+ roll, pitch, yaw = euler_angles[0], euler_angles[1], euler_angles[2]
181
+
182
+ # 简化的欧拉角到四元数转换
183
+ cy = np.cos(yaw * 0.5)
184
+ sy = np.sin(yaw * 0.5)
185
+ cp = np.cos(pitch * 0.5)
186
+ sp = np.sin(pitch * 0.5)
187
+ cr = np.cos(roll * 0.5)
188
+ sr = np.sin(roll * 0.5)
189
+
190
+ qw = cr * cp * cy + sr * sp * sy
191
+ qx = sr * cp * cy - cr * sp * sy
192
+ qy = cr * sp * cy + sr * cp * sy
193
+ qz = cr * cp * sy - sr * sp * cy
194
+
195
+ pose_data['rotation'] = np.array([qw, qx, qy, qz], dtype=np.float32)
196
+ found_pose = True
197
+
198
+ if step_idx < 3:
199
+ print(f"使用action.rotation_delta: {rotation_delta}, 累积旋转: {cumulative_rotation}")
200
+ except Exception as e:
201
+ if step_idx < 3:
202
+ print(f"action.rotation_delta提取失败: {e}")
203
+
204
+ # 确保rotation字段存在
205
+ if 'rotation' not in pose_data:
206
+ # 使用当前累积的旋转计算四元数
207
+ roll, pitch, yaw = cumulative_rotation[0], cumulative_rotation[1], cumulative_rotation[2]
208
+
209
+ cy = np.cos(yaw * 0.5)
210
+ sy = np.sin(yaw * 0.5)
211
+ cp = np.cos(pitch * 0.5)
212
+ sp = np.sin(pitch * 0.5)
213
+ cr = np.cos(roll * 0.5)
214
+ sr = np.sin(roll * 0.5)
215
+
216
+ qw = cr * cp * cy + sr * sp * sy
217
+ qx = sr * cp * cy - cr * sp * sy
218
+ qy = cr * sp * cy + sr * cp * sy
219
+ qz = cr * cp * sy - sr * sp * cy
220
+
221
+ pose_data['rotation'] = np.array([qw, qx, qy, qz], dtype=np.float32)
222
+
223
+ camera_poses.append(pose_data)
224
+ frame_count += 1
225
+
226
+ except Exception as e:
227
+ print(f"提取位姿步骤 {step_idx} 时出错: {e}")
228
+ # 添加默认位姿
229
+ pose_data = {
230
+ 'translation': cumulative_translation.copy(),
231
+ 'rotation': np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
232
+ }
233
+ camera_poses.append(pose_data)
234
+ frame_count += 1
235
+
236
+ print(f"提取了 {len(camera_poses)} 个位姿")
237
+ print(f"最终累积位移: {cumulative_translation}")
238
+ print(f"最终累积旋转: {cumulative_rotation}")
239
+
240
+ return camera_poses
241
+
242
+ def create_camera_matrices(self, camera_poses):
243
+ """将位姿转换为4x4变换矩阵"""
244
+ matrices = []
245
+
246
+ for pose in camera_poses:
247
+ matrix = np.eye(4, dtype=np.float32)
248
+
249
+ # 设置平移
250
+ matrix[:3, 3] = pose['translation']
251
+
252
+ # 设置旋转 - 假设是四元数 [w, x, y, z]
253
+ if len(pose['rotation']) == 4:
254
+ # 四元数转旋转矩阵
255
+ q = pose['rotation']
256
+ w, x, y, z = q[0], q[1], q[2], q[3]
257
+
258
+ # 四元数到旋转矩阵的转换
259
+ matrix[0, 0] = 1 - 2*(y*y + z*z)
260
+ matrix[0, 1] = 2*(x*y - w*z)
261
+ matrix[0, 2] = 2*(x*z + w*y)
262
+ matrix[1, 0] = 2*(x*y + w*z)
263
+ matrix[1, 1] = 1 - 2*(x*x + z*z)
264
+ matrix[1, 2] = 2*(y*z - w*x)
265
+ matrix[2, 0] = 2*(x*z - w*y)
266
+ matrix[2, 1] = 2*(y*z + w*x)
267
+ matrix[2, 2] = 1 - 2*(x*x + y*y)
268
+ elif len(pose['rotation']) == 3:
269
+ # 欧拉角转换(如果需要)
270
+ pass
271
+
272
+ matrices.append(matrix)
273
+
274
+ return np.array(matrices)
275
+
276
def encode_fractal_dataset(dataset_path, text_encoder_path, vae_path, output_dir, max_episodes=None):
    """Encode the fractal20220817_data dataset into Wan2.1 VAE latents.

    For each episode: extracts frames and camera poses, encodes the video
    with the VAE, attaches a shared text-prompt embedding, and saves the
    bundle as `<output_dir>/episode_XXXXXX/encoded_video.pth`. Episodes with
    an existing output file are skipped.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    os.makedirs(output_dir, exist_ok=True)

    processed_count = 0
    prompt_emb = None  # encoded lazily on the first episode, then reused

    try:
        # Load the dataset from the local TFDS data dir (GCS disabled above).
        ds = tfds.load(
            "fractal20220817_data",
            split="train",
            data_dir=dataset_path,
        )

        print(f"✅ 成功加载fractal20220817_data数据集")

        # Optionally cap the number of episodes processed.
        if max_episodes:
            ds = ds.take(max_episodes)
            print(f"限制处理episodes数量: {max_episodes}")

    except Exception as e:
        print(f"❌ 加载数据集失败: {e}")
        return

    for episode_idx, episode in enumerate(tqdm(ds, desc="处理episodes")):
        try:
            episode_name = f"episode_{episode_idx:06d}"
            save_episode_dir = os.path.join(output_dir, episode_name)

            # Skip episodes that were already encoded.
            encoded_path = os.path.join(save_episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                print(f"Episode {episode_name} 已处理,跳过...")
                processed_count += 1
                continue

            os.makedirs(save_episode_dir, exist_ok=True)

            print(f"\n🔧 处理episode {episode_name}...")

            # Structure diagnostics for the first couple of episodes only.
            if episode_idx < 2:
                print("Episode结构分析:")
                for key in episode.keys():
                    print(f"  - {key}: {type(episode[key])}")

                # Inspect the first step's structure.
                steps = episode['steps']
                for step in steps.take(1):
                    print("第一个step结构:")
                    for key in step.keys():
                        print(f"  - {key}: {type(step[key])}")

                    if 'observation' in step:
                        obs = step['observation']
                        print("  observation键:")
                        print(f"    🔍 可用字段: {list(obs.keys())}")

                        # Probe the image- and pose-related observation fields.
                        key_fields = ['image', 'vector_to_go', 'rotation_delta_to_go', 'base_pose_tool_reached']
                        for key in key_fields:
                            if key in obs:
                                try:
                                    value = obs[key]
                                    if hasattr(value, 'shape'):
                                        print(f"    ✅ {key}: {type(value)}, shape: {value.shape}")
                                    else:
                                        print(f"    ✅ {key}: {type(value)}")
                                except Exception as e:
                                    print(f"    ❌ {key}: 无法访问 ({e})")

                    if 'action' in step:
                        action = step['action']
                        print("  action键:")
                        print(f"    🔍 可用字段: {list(action.keys())}")

                        # Probe the pose-related action fields.
                        key_fields = ['world_vector', 'rotation_delta', 'base_displacement_vector']
                        for key in key_fields:
                            if key in action:
                                try:
                                    value = action[key]
                                    if hasattr(value, 'shape'):
                                        print(f"    ✅ {key}: {type(value)}, shape: {value.shape}")
                                    else:
                                        print(f"    ✅ {key}: {type(value)}")
                                except Exception as e:
                                    print(f"    ❌ {key}: 无法访问 ({e})")

            # Extract the video frames.
            video_frames = encoder.load_episode_frames(episode)
            if video_frames is None:
                print(f"❌ 无法加载episode {episode_name}的视频帧")
                continue

            print(f"✅ Episode {episode_name} 视频形状: {video_frames.shape}")

            # Extract camera poses matching the extracted frame count.
            num_frames = video_frames.shape[1]
            camera_poses = encoder.extract_camera_poses(episode, num_frames)
            camera_matrices = encoder.create_camera_matrices(camera_poses)

            print(f"🔧 编码episode {episode_name}...")

            # Camera conditioning: per-frame extrinsics + placeholder intrinsic.
            cam_emb = {
                'extrinsic': camera_matrices,
                'intrinsic': np.eye(3, dtype=np.float32)
            }

            # Encode the video into latents.
            frames_batch = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)

            with torch.no_grad():
                latents = encoder.pipe.encode_video(frames_batch, **encoder.tiler_kwargs)[0]

            # Encode the shared text prompt once; drop the prompter afterwards
            # to free memory.
            if prompt_emb is None:
                print('🔧 编码prompt...')
                prompt_emb = encoder.pipe.encode_prompt(
                    "A video of robotic manipulation task with camera movement"
                )
                del encoder.pipe.prompter

            # Persist everything needed by the training pipeline.
            encoded_data = {
                "latents": latents.cpu(),
                "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v
                               for k, v in prompt_emb.items()},
                "cam_emb": cam_emb,
                "episode_info": {
                    "episode_idx": episode_idx,
                    "total_frames": video_frames.shape[1],
                    "pose_extraction_method": "observation_action_based"
                }
            }

            torch.save(encoded_data, encoded_path)
            print(f"✅ 保存编码数据: {encoded_path}")

            processed_count += 1
            print(f"✅ 已处理 {processed_count} 个episodes")

        except Exception as e:
            print(f"❌ 处理episode {episode_idx}时出错: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"🎉 编码完成! 总共处理了 {processed_count} 个episodes")
434
if __name__ == "__main__":
    # Command-line entry point; defaults target the shared cluster layout.
    parser = argparse.ArgumentParser(description="Encode Open-X Fractal20220817 Dataset - Based on Real Structure")
    parser.add_argument("--dataset_path", type=str,
                        default="/share_zhuyixuan05/public_datasets/open-x/0.1.0",
                        help="Path to tensorflow_datasets directory")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
    parser.add_argument("--output_dir", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded")
    parser.add_argument("--max_episodes", type=int, default=10000,
                        help="Maximum number of episodes to process (default: 10 for testing)")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Startup banner.
    for line in (
        "🚀 开始编码Open-X Fractal数据集 (基于实际字段结构)...",
        f"📁 数据集路径: {args.dataset_path}",
        f"💾 输出目录: {args.output_dir}",
        f"🔢 最大处理episodes: {args.max_episodes}",
        "🔧 基于实际observation和action字段的位姿提取方法",
        "✅ 优先使用 'image' 字段获取图像数据",
    ):
        print(line)

    encode_fractal_dataset(
        args.dataset_path,
        args.text_encoder_path,
        args.vae_path,
        args.output_dir,
        args.max_episodes,
    )
scripts/encode_rlbench_video.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ import pdb
13
+ from tqdm import tqdm
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
class VideoEncoder(pl.LightningModule):
    """Loads the Wan2.1 T5 text encoder + VAE and exposes video -> latent encoding.

    Only the encoder components are loaded (no diffusion weights); `self.pipe`
    is used downstream exclusively for `encode_video` / `encode_prompt`.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Build the pipeline on CPU in bfloat16; the caller moves it to GPU.

        Args:
            text_encoder_path: path to the UMT5 text-encoder checkpoint.
            vae_path: path to the Wan2.1 VAE checkpoint.
            tiled / tile_size / tile_stride: forwarded to the VAE so large
                frames are encoded tile-by-tile to bound GPU memory use.
        """
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Map uint8 PIL frames to float tensors normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=512, target_height=512):
        """Resize a PIL image to the encoder resolution (default 512x512).

        The target size is parameterized (backward-compatible defaults) instead
        of hard-coded; the previously computed-but-unused source width/height
        locals were removed.
        """
        return v2.functional.resize(
            image,
            (round(target_height), round(target_width)),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Decode every frame of `video_path` into a (C, T, H, W) float tensor.

        Returns None when the video yields no frames. The reader is closed in
        a `finally` block so a decode error cannot leak the file handle.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frame = self.frame_process(frame)
                frames.append(frame)
        finally:
            reader.close()

        if not frames:
            return None

        frames = torch.stack(frames, dim=0)
        return rearrange(frames, "T C H W -> C T H W")
62
+
63
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
    """Encode every RLBench demo video under `scenes_path` into VAE latents.

    Directory layout assumed: <scenes_path>/<scene>/<demo>/*.mp4 with a sibling
    pose file whose name replaces "side.mp4" with "data.npy". Each demo writes
    <output_dir>/<scene>_<demo>/encoded_video.pth containing {"latents", "cam_emb"}.
    """

    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    prompt_emb = 0  # unused here; prompt encoding is commented out below

    os.makedirs(output_dir,exist_ok=True)

    for i, scene_name in enumerate(os.listdir(scenes_path)):
        scene_dir = os.path.join(scenes_path, scene_name)
        for j, demo_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))):
            demo_dir = os.path.join(scene_dir, demo_name)
            for filename in os.listdir(demo_dir):
                # Only encode .mp4 files (case-insensitive extension check).
                if filename.lower().endswith('.mp4'):
                    full_path = os.path.join(demo_dir, filename)
                    print(full_path)
                    save_dir = os.path.join(output_dir,scene_name+'_'+demo_name)

                    os.makedirs(save_dir,exist_ok=True)
                    # Resume support: skip demos that already have an encoded file.
                    encoded_path = os.path.join(save_dir, "encoded_video.pth")
                    if os.path.exists(encoded_path):
                        print(f"Scene {scene_name} already encoded, skipping...")
                        continue

                    # Camera/pose data is expected next to the video.
                    # NOTE(review): if the filename does not end in "side.mp4" this
                    # replace is a no-op and scene_cam_path points at the video itself.
                    scene_cam_path = full_path.replace("side.mp4", "data.npy")
                    print(scene_cam_path)
                    if not os.path.exists(scene_cam_path):
                        continue

                    cam_data = np.load(scene_cam_path)
                    cam_emb = cam_data
                    print(cam_data.shape)

                    video_path = full_path
                    if not os.path.exists(video_path):
                        print(f"Video not found: {video_path}")
                        continue

                    print(f"Encoding scene {scene_name}...Demo {demo_name}")

                    # Decode the full video; None means zero frames.
                    video_frames = encoder.load_video_frames(video_path)
                    if video_frames is None:
                        print(f"Failed to load video: {video_path}")
                        continue

                    # (1, C, T, H, W) on GPU in bfloat16 for the VAE.
                    video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
                    print('video shape:',video_frames.shape)
                    # Encode video to latents (no grad needed for inference).
                    with torch.no_grad():
                        latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]

                    # Persist latents + raw camera array; prompt embedding is
                    # intentionally omitted (see commented-out code in history).
                    encoded_data = {
                        "latents": latents.cpu(),
                        "cam_emb": cam_emb
                    }
                    torch.save(encoded_data, encoded_path)
                    print(f"Saved encoded data: {encoded_path}")
                    processed_count += 1

    print(f"Encoding completed! Processed {processed_count} scenes.")
157
+
158
if __name__ == "__main__":
    # CLI entry point: all four options are plain string paths, so build the
    # parser from a flag -> default table instead of repeated add_argument calls.
    _defaults = {
        "--scenes_path": "/share_zhuyixuan05/zhuyixuan05/RLBench",
        "--text_encoder_path": "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "--vae_path": "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
        "--output_dir": "/share_zhuyixuan05/zhuyixuan05/rlbench",
    }
    parser = argparse.ArgumentParser()
    for _flag, _default in _defaults.items():
        parser.add_argument(_flag, type=str, default=_default)
    args = parser.parse_args()
    encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path, args.output_dir)
scripts/encode_sekai_video.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ import pdb
13
+ from tqdm import tqdm
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
class VideoEncoder(pl.LightningModule):
    """Loads the Wan2.1 T5 text encoder + VAE and exposes video -> latent encoding.

    Only the encoder components are loaded (no diffusion weights); `self.pipe`
    is used downstream exclusively for `encode_video` / `encode_prompt`.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Build the pipeline on CPU in bfloat16; the caller moves it to GPU.

        Args:
            text_encoder_path: path to the UMT5 text-encoder checkpoint.
            vae_path: path to the Wan2.1 VAE checkpoint.
            tiled / tile_size / tile_stride: forwarded to the VAE so large
                frames are encoded tile-by-tile to bound GPU memory use.
        """
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Map uint8 PIL frames to float tensors normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=832, target_height=480):
        """Resize a PIL image to the encoder resolution (default 832x480).

        The target size is parameterized (backward-compatible defaults) instead
        of hard-coded; the previously computed-but-unused source width/height
        locals were removed.
        """
        return v2.functional.resize(
            image,
            (round(target_height), round(target_width)),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Decode every frame of `video_path` into a (C, T, H, W) float tensor.

        Returns None when the video yields no frames. The reader is closed in
        a `finally` block so a decode error cannot leak the file handle.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frame = self.frame_process(frame)
                frames.append(frame)
        finally:
            reader.close()

        if not frames:
            return None

        frames = torch.stack(frames, dim=0)
        return rearrange(frames, "T C H W -> C T H W")
62
+
63
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
    """Encode every Sekai walking clip (<clip>.mp4) under `scenes_path`.

    For each .mp4 with a sibling <clip>.npz of camera parameters, writes
    <output_dir>/<clip>/encoded_video.pth containing {"latents", "cam_emb"}.
    """

    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    prompt_emb = 0  # placeholder until the shared prompt is encoded once below

    os.makedirs(output_dir,exist_ok=True)

    for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))):
        scene_dir = os.path.join(scenes_path, scene_name)
        save_dir = os.path.join(output_dir,scene_name.split('.')[0])

        # Only .mp4 files are encodable; skip the .npz camera files and misc.
        if not scene_dir.endswith(".mp4"):
            continue

        os.makedirs(save_dir,exist_ok=True)
        # Resume support: skip clips that already have an encoded file.
        encoded_path = os.path.join(save_dir, "encoded_video.pth")
        if os.path.exists(encoded_path):
            print(f"Scene {scene_name} already encoded, skipping...")
            continue

        # Camera parameters live next to the video as <clip>.npz.
        scene_cam_path = scene_dir.replace(".mp4", ".npz")
        if not os.path.exists(scene_cam_path):
            continue

        with np.load(scene_cam_path) as data:
            cam_data = data.files
            # NOTE(review): np.load yields ndarrays, so the .cpu() branch is dead code.
            cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}

        video_path = scene_dir
        if not os.path.exists(video_path):
            print(f"Video not found: {video_path}")
            continue

        print(f"Encoding scene {scene_name}...")

        # Decode the whole clip; None means zero frames.
        video_frames = encoder.load_video_frames(video_path)
        if video_frames is None:
            print(f"Failed to load video: {video_path}")
            continue

        # (1, C, T, H, W) on GPU in bfloat16 for the VAE.
        video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
        print('video shape:',video_frames.shape)
        with torch.no_grad():
            latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]

        # Encode the shared prompt once, then drop the prompter to free memory.
        # NOTE(review): prompt_emb is computed but never saved (the dict entry is
        # commented out below) — confirm whether it should be persisted.
        if processed_count == 0:
            print('encode prompt!!!')
            prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
            del encoder.pipe.prompter
        encoded_data = {
            "latents": latents.cpu(),
            # "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
            "cam_emb": cam_emb
        }
        torch.save(encoded_data, encoded_path)
        print(f"Saved encoded data: {encoded_path}")
        processed_count += 1

    print(f"Encoding completed! Processed {processed_count} scenes.")
149
+
150
if __name__ == "__main__":
    # CLI entry point: all four options are plain string paths, so build the
    # parser from a flag -> default table instead of repeated add_argument calls.
    _defaults = {
        "--scenes_path": "/share_zhuyixuan05/public_datasets/sekai/Sekai-Project/sekai-game-walking",
        "--text_encoder_path": "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "--vae_path": "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
        "--output_dir": "/share_zhuyixuan05/zhuyixuan05/sekai-game-walking",
    }
    parser = argparse.ArgumentParser()
    for _flag, _default in _defaults.items():
        parser.add_argument(_flag, type=str, default=_default)
    args = parser.parse_args()
    encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path, args.output_dir)
scripts/encode_sekai_walking.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import torch
4
+ import lightning as pl
5
+ from PIL import Image
6
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
7
+ import json
8
+ import imageio
9
+ from torchvision.transforms import v2
10
+ from einops import rearrange
11
+ import argparse
12
+ import numpy as np
13
+ import pdb
14
+ from tqdm import tqdm
15
+
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
17
+
18
class VideoEncoder(pl.LightningModule):
    """Loads the Wan2.1 T5 text encoder + VAE and exposes video -> latent encoding.

    Only the encoder components are loaded (no diffusion weights); `self.pipe`
    is used downstream exclusively for `encode_video` / `encode_prompt`.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Build the pipeline on CPU in bfloat16; the caller moves it to GPU.

        Args:
            text_encoder_path: path to the UMT5 text-encoder checkpoint.
            vae_path: path to the Wan2.1 VAE checkpoint.
            tiled / tile_size / tile_stride: forwarded to the VAE so large
                frames are encoded tile-by-tile to bound GPU memory use.
        """
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Map uint8 PIL frames to float tensors normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=832, target_height=480):
        """Resize a PIL image to the encoder resolution (default 832x480).

        The target size is parameterized (backward-compatible defaults) instead
        of hard-coded; the previously computed-but-unused source width/height
        locals were removed.
        """
        return v2.functional.resize(
            image,
            (round(target_height), round(target_width)),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Decode every frame of `video_path` into a (C, T, H, W) float tensor.

        Returns None when the video yields no frames. The reader is closed in
        a `finally` block so a decode error cannot leak the file handle.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frame = self.frame_process(frame)
                frames.append(frame)
        finally:
            reader.close()

        if not frames:
            return None

        frames = torch.stack(frames, dim=0)
        return rearrange(frames, "T C H W -> C T H W")
63
+
64
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
    """Encode Sekai walking clips into fixed-size 300-frame latent chunks.

    Clip filenames encode their frame span as <video>_<start>_<end>.mp4. Each
    clip is split into chunks of `chunk_size` frames; every chunk is saved to
    <output_dir>/<video>_<start7>_<end7>/encoded_video.pth with its sliced
    camera extrinsics and the shared intrinsics.
    """

    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0        # number of source clips fully processed
    processed_chunk_count = 0  # number of chunk files written
    prompt_emb = 0             # unused here; prompt encoding is commented out

    os.makedirs(output_dir,exist_ok=True)
    chunk_size = 300
    for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))):
        print('index:',i)
        scene_dir = os.path.join(scenes_path, scene_name)

        # Only .mp4 files are encodable; skip the .npz camera files and misc.
        if not scene_dir.endswith(".mp4"):
            continue

        # Camera parameters live next to the video as <clip>.npz.
        scene_cam_path = scene_dir.replace(".mp4", ".npz")
        if not os.path.exists(scene_cam_path):
            continue

        with np.load(scene_cam_path) as data:
            cam_data = data.files
            # NOTE(review): np.load yields ndarrays, so the .cpu() branch is dead code.
            cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}

        # Filename convention: <video>_<start_frame>_<end_frame>.mp4
        video_name = scene_name[:-4].split('_')[0]
        start_frame = int(scene_name[:-4].split('_')[1])
        end_frame = int(scene_name[:-4].split('_')[2])

        # Chunk start offsets, every `chunk_size` frames.
        sampled_range = range(start_frame, end_frame , chunk_size)
        sampled_frames = list(sampled_range)

        # Cheap resume check on the FIRST chunk only: if it exists, assume the
        # whole clip was done. NOTE(review): this skips clips whose later
        # chunks are missing — confirm this is acceptable.
        sampled_chunk_end = sampled_frames[0] + 300
        start_str = f"{sampled_frames[0]:07d}"
        end_str = f"{sampled_chunk_end:07d}"

        chunk_name = f"{video_name}_{start_str}_{end_str}"
        save_chunk_path = os.path.join(output_dir,chunk_name,"encoded_video.pth")

        if os.path.exists(save_chunk_path):
            print(f"Video {video_name} already encoded, skipping...")
            continue

        video_path = scene_dir
        if not os.path.exists(video_path):
            print(f"Video not found: {video_path}")
            continue

        # Decode the whole clip once; chunks are sliced from this tensor.
        video_frames = encoder.load_video_frames(video_path)
        if video_frames is None:
            print(f"Failed to load video: {video_path}")
            continue

        # (1, C, T, H, W) on GPU in bfloat16 for the VAE.
        video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
        print('video shape:',video_frames.shape)

        print(f"Encoding scene {scene_name}...")
        for sampled_chunk_start in sampled_frames:
            # NOTE(review): 300 is hard-coded here while `chunk_size` is used
            # above — keep them in sync if chunk_size ever changes.
            sampled_chunk_end = sampled_chunk_start + 300
            start_str = f"{sampled_chunk_start:07d}"
            end_str = f"{sampled_chunk_end:07d}"

            chunk_name = f"{video_name}_{start_str}_{end_str}"
            save_chunk_dir = os.path.join(output_dir,chunk_name)

            os.makedirs(save_chunk_dir,exist_ok=True)
            print(f"Encoding chunk {chunk_name}...")

            encoded_path = os.path.join(save_chunk_dir, "encoded_video.pth")

            if os.path.exists(encoded_path):
                print(f"Chunk {chunk_name} already encoded, skipping...")
                continue

            # Slice frames/extrinsics for this chunk. NOTE(review): for the
            # final chunk the slice silently truncates to fewer than 300
            # frames when end_frame is not a multiple of chunk_size.
            chunk_frames = video_frames[:,:, sampled_chunk_start - start_frame : sampled_chunk_end - start_frame,...]
            chunk_cam_emb ={'extrinsic':cam_emb['extrinsic'][sampled_chunk_start - start_frame : sampled_chunk_end - start_frame],
                            'intrinsic':cam_emb['intrinsic']}

            with torch.no_grad():
                latents = encoder.pipe.encode_video(chunk_frames, **encoder.tiler_kwargs)[0]

            # Persist latents + per-chunk camera data (no prompt embedding).
            encoded_data = {
                "latents": latents.cpu(),
                "cam_emb": chunk_cam_emb
            }
            torch.save(encoded_data, encoded_path)
            print(f"Saved encoded data: {encoded_path}")
            processed_chunk_count += 1

        processed_count += 1

        print("Encoded scene numebr:",processed_count)
        print("Encoded chunk numebr:",processed_chunk_count)

    print(f"Encoding completed! Processed {processed_count} scenes.")
236
+
237
if __name__ == "__main__":
    # CLI entry point: all four options are plain string paths, so build the
    # parser from a flag -> default table instead of repeated add_argument calls.
    _defaults = {
        "--scenes_path": "/share_zhuyixuan05/public_datasets/sekai/Sekai-Project/sekai-game-walking",
        "--text_encoder_path": "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "--vae_path": "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
        "--output_dir": "/share_zhuyixuan05/zhuyixuan05/sekai-game-walking",
    }
    parser = argparse.ArgumentParser()
    for _flag, _default in _defaults.items():
        parser.add_argument(_flag, type=str, default=_default)
    args = parser.parse_args()
    encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path, args.output_dir)
scripts/encode_spatialvid.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import torch
4
+ import lightning as pl
5
+ from PIL import Image
6
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
7
+ import json
8
+ import imageio
9
+ from torchvision.transforms import v2
10
+ from einops import rearrange
11
+ import argparse
12
+ import numpy as np
13
+ import pdb
14
+ from tqdm import tqdm
15
+ import pandas as pd
16
+
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+
19
+ from scipy.spatial.transform import Slerp
20
+ from scipy.spatial.transform import Rotation as R
21
+
22
def interpolate_camera_poses(original_frames, original_poses, target_frames):
    """Interpolate camera poses at the requested frame indices.

    Args:
        original_frames: frame indices with known poses, e.g. [0, 6, 12, ...].
        original_poses: (n, 7) array; each row is [tx, ty, tz, qx, qy, qz, qw].
        target_frames: frame indices to produce poses for, e.g. [0, 4, 8, ...].

    Returns:
        (m, 7) array: linear interpolation for translation, SLERP for rotation.
        Targets before the first / after the last known frame are clamped to
        the first / last pose respectively.
    """
    print('original_frames:',len(original_frames))
    print('original_poses:',len(original_poses))
    if len(original_frames) != len(original_poses):
        raise ValueError("原始帧数量与姿态数量不匹配")
    if original_poses.shape[1] != 7:
        raise ValueError(f"原始姿态应为(n,7)格式,实际为{original_poses.shape}")

    # Quaternion columns as a single Rotation sequence for SLERP.
    rotations = R.from_quat(original_poses[:, 3:7])

    result = []
    for frame in target_frames:
        pos = np.searchsorted(original_frames, frame, side='left')

        # Clamp targets outside the known range to the boundary poses.
        if pos == 0:
            result.append(original_poses[0])
        elif pos >= len(original_frames):
            result.append(original_poses[-1])
        else:
            lo = original_frames[pos - 1]
            hi = original_frames[pos]
            weight = (frame - lo) / (hi - lo)

            # Translation: straight linear blend between neighbours.
            prev_t = original_poses[pos - 1][:3]
            next_t = original_poses[pos][:3]
            blended_t = prev_t + weight * (next_t - prev_t)

            # Rotation: spherical linear interpolation between neighbours.
            blended_q = Slerp([lo, hi], rotations[pos - 1:pos + 1])(frame).as_quat()

            result.append(np.concatenate([blended_t, blended_q]))

    return np.array(result)
88
+
89
+
90
class VideoEncoder(pl.LightningModule):
    """Loads the Wan2.1 T5 text encoder + VAE and exposes video -> latent encoding.

    Only the encoder components are loaded (no diffusion weights); `self.pipe`
    is used downstream exclusively for `encode_video` / `encode_prompt`.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Build the pipeline on CPU in bfloat16; the caller moves it to GPU.

        Args:
            text_encoder_path: path to the UMT5 text-encoder checkpoint.
            vae_path: path to the Wan2.1 VAE checkpoint.
            tiled / tile_size / tile_stride: forwarded to the VAE so large
                frames are encoded tile-by-tile to bound GPU memory use.
        """
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Map uint8 PIL frames to float tensors normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=832, target_height=480):
        """Resize a PIL image to the encoder resolution (default 832x480).

        The target size is parameterized (backward-compatible defaults) instead
        of hard-coded; the previously computed-but-unused source width/height
        locals were removed.
        """
        return v2.functional.resize(
            image,
            (round(target_height), round(target_width)),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Decode every frame of `video_path` into a (C, T, H, W) float tensor.

        Returns None when the video yields no frames. The reader is closed in
        a `finally` block so a decode error cannot leak the file handle.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frame = self.frame_process(frame)
                frames.append(frame)
        finally:
            reader.close()

        if not frames:
            return None

        frames = torch.stack(frames, dim=0)
        return rearrange(frames, "T C H W -> C T H W")
135
+
136
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
    """Encode SpatialVID-HQ videos into 300-frame latent chunks with captions.

    Layout assumed: <scenes_path>/<group>/<id>.mp4, with pose (poses.npy) and
    caption (caption.json) under the parallel "annotations" tree. Each chunk
    writes <output_dir>/<video>_<start7>_<end7>/encoded_video.pth containing
    {"latents", "prompt_emb", "cam_emb"}; existing files are repaired by
    filling in only the missing keys.
    """

    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0        # number of source videos fully processed
    processed_chunk_count = 0  # number of chunk files written/repaired

    prompt_emb = 0  # placeholder; set per-video from the caption below

    # NOTE(review): metadata CSV path is hard-coded rather than derived from
    # `scenes_path` / CLI args — confirm this is intentional.
    metadata = pd.read_csv('/share_zhuyixuan05/public_datasets/SpatialVID-HQ/data/train/SpatialVID_HQ_metadata.csv')

    os.makedirs(output_dir,exist_ok=True)
    chunk_size = 300
    # A saved chunk is considered complete only if it has all three keys.
    required_keys = ["latents", "cam_emb", "prompt_emb"]

    for i, scene_name in enumerate(os.listdir(scenes_path)):
        # NOTE(review): the first three groups are unconditionally skipped —
        # looks like a leftover manual-sharding hack; confirm before reuse.
        if i < 3 :
            continue
        print('group:',i)
        scene_dir = os.path.join(scenes_path, scene_name)

        print('in:',scene_dir)
        for j, video_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))):
            print(video_name)
            video_path = os.path.join(scene_dir, video_name)
            if not video_path.endswith(".mp4"):
                continue

            # Frame count comes from the dataset metadata row for this id.
            video_info = metadata[metadata['id'] == video_name[:-4]]
            num_frames = video_info['num frames'].iloc[0]

            # Annotations mirror the videos tree: .../annotations/<id>/poses.npy
            scene_cam_dir = video_path.replace( "videos","annotations")[:-4]
            scene_cam_path = os.path.join(scene_cam_dir,'poses.npy')

            scene_caption_path = os.path.join(scene_cam_dir,'caption.json')

            with open(scene_caption_path, 'r', encoding='utf-8') as f:
                caption_data = json.load(f)
            caption = caption_data["SceneSummary"]
            if not os.path.exists(scene_cam_path):
                print(f"Pose not found: {scene_cam_path}")
                continue

            camera_poses = np.load(scene_cam_path)
            cam_data_len = camera_poses.shape[0]

            if not os.path.exists(video_path):
                print(f"Video not found: {video_path}")
                continue

            # Fast path: if the first chunk already has latents, avoid decoding
            # the whole video (video_frames is set to the sentinel 1 and only
            # re-decoded later if some chunk turns out to need latents).
            start_str = f"{0:07d}"
            end_str = f"{chunk_size:07d}"
            # NOTE(review): this pre-check names the chunk with the FULL id
            # (video_name[:-4]) while the save path below uses only the prefix
            # before '_' — for ids containing '_', these disagree.
            chunk_name = f"{video_name[:-4]}_{start_str}_{end_str}"
            first_save_chunk_dir = os.path.join(output_dir,chunk_name)

            first_chunk_encoded_path = os.path.join(first_save_chunk_dir, "encoded_video.pth")
            if os.path.exists(first_chunk_encoded_path):
                data = torch.load(first_chunk_encoded_path,weights_only=False)
                if 'latents' in data:
                    video_frames = 1
                else:
                    video_frames = encoder.load_video_frames(video_path)
                    if video_frames is None:
                        print(f"Failed to load video: {video_path}")
                        continue
                    print('video shape:',video_frames.shape)

                    video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
                    print('video shape:',video_frames.shape)
            # NOTE(review): when the first-chunk file does not exist,
            # video_frames keeps its value from the previous loop iteration
            # (or is undefined on the very first video) — the later
            # isinstance() guard may then reuse a stale tensor; verify.

            video_name = video_name[:-4].split('_')[0]
            start_frame = 0
            end_frame = num_frames

            # Known pose timestamps are spread evenly over the whole video.
            cam_interval = end_frame // (cam_data_len - 1)

            cam_frames = np.linspace(start_frame, end_frame, cam_data_len, endpoint=True)
            cam_frames = np.round(cam_frames).astype(int)
            cam_frames = cam_frames.tolist()

            # Chunk start offsets, every `chunk_size` frames.
            sampled_range = range(start_frame, end_frame , chunk_size)
            sampled_frames = list(sampled_range)

            sampled_chunk_end = sampled_frames[0] + chunk_size
            start_str = f"{sampled_frames[0]:07d}"
            end_str = f"{sampled_chunk_end:07d}"

            chunk_name = f"{video_name}_{start_str}_{end_str}"

            print(f"Encoding scene {video_name}...")
            chunk_count_in_one_video = 0
            for sampled_chunk_start in sampled_frames:
                # Skip tail chunks shorter than 100 frames.
                if num_frames - sampled_chunk_start < 100:
                    continue
                sampled_chunk_end = sampled_chunk_start + chunk_size
                start_str = f"{sampled_chunk_start:07d}"
                end_str = f"{sampled_chunk_end:07d}"

                # Camera poses are resampled every 4 frames within the chunk.
                resample_cam_frame = list(range(sampled_chunk_start, sampled_chunk_end , 4))

                chunk_name = f"{video_name}_{start_str}_{end_str}"
                save_chunk_dir = os.path.join(output_dir,chunk_name)

                os.makedirs(save_chunk_dir,exist_ok=True)
                print(f"Encoding chunk {chunk_name}...")

                encoded_path = os.path.join(save_chunk_dir, "encoded_video.pth")

                # Repair mode: recompute only the keys missing from an existing file.
                missing_keys = required_keys
                if os.path.exists(encoded_path):
                    print('error:',encoded_path)
                    data = torch.load(encoded_path,weights_only=False)
                    missing_keys = [key for key in required_keys if key not in data]
                    if missing_keys:
                        print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
                    if len(missing_keys) == 0 :
                        continue
                else:
                    print(f"警告: 缺少pth文件: {encoded_path}")
                # Decode the full video lazily, only when latents must be built.
                if not isinstance(video_frames, torch.Tensor):
                    video_frames = encoder.load_video_frames(video_path)
                    if video_frames is None:
                        print(f"Failed to load video: {video_path}")
                        continue

                    video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)

                print('video shape:',video_frames.shape)
                if "latents" in missing_keys:
                    chunk_frames = video_frames[:,:, sampled_chunk_start - start_frame : sampled_chunk_end - start_frame,...]

                    with torch.no_grad():
                        latents = encoder.pipe.encode_video(chunk_frames, **encoder.tiler_kwargs)[0]
                else:
                    latents = data['latents']
                if "cam_emb" in missing_keys:
                    # Interpolate sparse annotated poses onto the resampled grid.
                    cam_emb = interpolate_camera_poses(cam_frames, camera_poses,resample_cam_frame)
                    chunk_cam_emb ={'extrinsic':cam_emb}
                    print(f"视频长度:{chunk_size},重采样相机长度:{cam_emb.shape[0]}")
                else:
                    chunk_cam_emb = data['cam_emb']

                if "prompt_emb" in missing_keys:
                    # The caption is shared by all chunks of a video; encode once.
                    if chunk_count_in_one_video == 0:
                        print(caption)
                        with torch.no_grad():
                            prompt_emb = encoder.pipe.encode_prompt(caption)
                else:
                    prompt_emb = data['prompt_emb']

                # Persist latents + prompt embedding + per-chunk camera data.
                encoded_data = {
                    "latents": latents.cpu(),
                    "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
                    "cam_emb": chunk_cam_emb
                }
                torch.save(encoded_data, encoded_path)
                print(f"Saved encoded data: {encoded_path}")
                processed_chunk_count += 1
                chunk_count_in_one_video += 1

            processed_count += 1

            print("Encoded scene numebr:",processed_count)
            print("Encoded chunk numebr:",processed_chunk_count)

    print(f"Encoding completed! Processed {processed_count} scenes.")
396
+
397
if __name__ == "__main__":
    # CLI entry point: all four options are plain string paths, so build the
    # parser from a flag -> default table instead of repeated add_argument calls.
    _defaults = {
        "--scenes_path": "/share_zhuyixuan05/public_datasets/SpatialVID-HQ/SpatialVid/HQ/videos/",
        "--text_encoder_path": "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "--vae_path": "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
        "--output_dir": "/share_zhuyixuan05/zhuyixuan05/spatialvid",
    }
    parser = argparse.ArgumentParser()
    for _flag, _default in _defaults.items():
        parser.add_argument(_flag, type=str, default=_default)
    args = parser.parse_args()
    encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path, args.output_dir)
scripts/encode_spatialvid_first_frame.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import torch
4
+ import lightning as pl
5
+ from PIL import Image
6
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
7
+ import json
8
+ import imageio
9
+ from torchvision.transforms import v2
10
+ from einops import rearrange
11
+ import argparse
12
+ import numpy as np
13
+ import pdb
14
+ from tqdm import tqdm
15
+ import pandas as pd
16
+
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+
19
+ from scipy.spatial.transform import Slerp
20
+ from scipy.spatial.transform import Rotation as R
21
+
22
def interpolate_camera_poses(original_frames, original_poses, target_frames):
    """Interpolate camera poses at the requested target frame indices.

    Translation is linearly interpolated and rotation is SLERPed between the
    two original keyframes bracketing each target frame. Targets before the
    first / after the last keyframe are clamped to the first / last pose.

    Args:
        original_frames: sorted original frame indices, e.g. [0, 6, 12, ...].
        original_poses: (n, 7) array; each row is [tx, ty, tz, qx, qy, qz, qw].
        target_frames: target frame indices, e.g. [0, 4, 8, 12, ...].

    Returns:
        (m, 7) array of interpolated poses, one row per target frame.
    """
    print('original_frames:',len(original_frames))
    print('original_poses:',len(original_poses))
    if len(original_frames) != len(original_poses):
        raise ValueError("原始帧数量与姿态数量不匹配")

    if original_poses.shape[1] != 7:
        raise ValueError(f"原始姿态应为(n,7)格式,实际为{original_poses.shape}")

    # Precompute Rotation objects for all keyframe quaternions.
    key_rotations = R.from_quat(original_poses[:, 3:7])

    def pose_at(t):
        """Interpolated pose for a single target frame t (clamped at the ends)."""
        idx = np.searchsorted(original_frames, t, side='left')

        if idx == 0:
            return original_poses[0]
        if idx >= len(original_frames):
            return original_poses[-1]

        t_prev, t_next = original_frames[idx - 1], original_frames[idx]
        alpha = (t - t_prev) / (t_next - t_prev)

        # Linear interpolation of the translation component.
        trans_prev = original_poses[idx - 1][:3]
        trans_next = original_poses[idx][:3]
        trans = trans_prev + alpha * (trans_next - trans_prev)

        # Spherical linear interpolation (SLERP) of the rotation component.
        rot = Slerp([t_prev, t_next], key_rotations[idx - 1:idx + 1])(t)

        return np.concatenate([trans, rot.as_quat()])

    return np.array([pose_at(t) for t in target_frames])
88
+
89
class VideoEncoder(pl.LightningModule):
    """Offline encoder wrapping the ReCamMaster pipeline's text encoder + VAE.

    Used only to pre-compute latents; no diffusion model is loaded. The caller
    is expected to move the module to GPU and set ``pipe.device`` afterwards
    (see ``encode_scenes``).
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        # Models are loaded on CPU first in bfloat16; GPU placement is done by the caller.
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiled VAE encoding bounds peak memory for large frames.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # PIL image -> float tensor normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        # Despite the name, this only resizes (no crop) to a fixed 832x480
        # (width x height), which may change the aspect ratio.
        width, height = image.size  # NOTE(review): original size is read but never used — confirm intended
        width_ori, height_ori_ = 832 , 480
        image = v2.functional.resize(
            image,
            (round(height_ori_), round(width_ori)),
            interpolation=v2.InterpolationMode.BILINEAR
        )
        return image

    def load_single_frame(self, video_path, frame_idx):
        """Load only the requested single frame.

        Returns a [1, C, 1, H, W] tensor, or None if the frame cannot be read.
        """
        reader = imageio.get_reader(video_path)

        try:
            # Seek directly to the requested frame.
            frame_data = reader.get_data(frame_idx)
            frame = Image.fromarray(frame_data)
            frame = self.crop_and_resize(frame)
            frame = self.frame_process(frame)

            # Add batch and time dimensions: [C, H, W] -> [1, C, 1, H, W]
            frame = frame.unsqueeze(0).unsqueeze(2)

        except Exception as e:
            # Best-effort: report the failure and signal it with None.
            print(f"Error loading frame {frame_idx} from {video_path}: {e}")
            return None
        finally:
            reader.close()

        return frame

    def load_video_frames(self, video_path):
        """Load the full video (kept for compatibility).

        Returns a [C, T, H, W] tensor, or None for an empty/unreadable video.
        """
        reader = imageio.get_reader(video_path)
        frames = []

        for frame_data in reader:
            frame = Image.fromarray(frame_data)
            frame = self.crop_and_resize(frame)
            frame = self.frame_process(frame)
            frames.append(frame)

        reader.close()

        if len(frames) == 0:
            return None

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames
153
+
154
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
    """Encode the first frame of each 300-frame chunk of every scene video.

    For each .mp4 under ``scenes_path/<group>/``, the video is split into
    fixed-size chunks; the chunk's first frame is loaded, repeated 4 times
    along the time axis, VAE-encoded, and saved as
    ``<output_dir>/<video>_<start>_<end>/first_latent.pth``.
    Chunks whose latent file already exists are skipped (resumable).
    """

    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    processed_chunk_count = 0

    # NOTE(review): metadata CSV path is hard-coded; 'id' is assumed to match
    # the video filename stem and 'num frames' its frame count — confirm.
    metadata = pd.read_csv('/share_zhuyixuan05/public_datasets/SpatialVID-HQ/data/train/SpatialVID_HQ_metadata.csv')

    os.makedirs(output_dir,exist_ok=True)
    chunk_size = 300

    for i, scene_name in enumerate(os.listdir(scenes_path)):
        # NOTE(review): the first two scene groups are skipped unconditionally —
        # presumably already processed elsewhere; confirm before reuse.
        if i < 2:
            continue
        print('group:',i)
        scene_dir = os.path.join(scenes_path, scene_name)

        print('in:',scene_dir)
        for j, video_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))):
            print(video_name)
            video_path = os.path.join(scene_dir, video_name)
            if not video_path.endswith(".mp4"):
                continue

            # Look up per-video metadata by filename stem.
            video_info = metadata[metadata['id'] == video_name[:-4]]
            num_frames = video_info['num frames'].iloc[0]

            # Annotations live in a parallel tree: .../videos/... -> .../annotations/...
            scene_cam_dir = video_path.replace("videos","annotations")[:-4]
            scene_cam_path = os.path.join(scene_cam_dir,'poses.npy')
            scene_caption_path = os.path.join(scene_cam_dir,'caption.json')

            # NOTE(review): caption.json is opened before any existence check and
            # `caption` is never used afterwards — this will raise if missing.
            with open(scene_caption_path, 'r', encoding='utf-8') as f:
                caption_data = json.load(f)
            caption = caption_data["SceneSummary"]

            if not os.path.exists(scene_cam_path):
                print(f"Pose not found: {scene_cam_path}")
                continue

            camera_poses = np.load(scene_cam_path)
            cam_data_len = camera_poses.shape[0]

            if not os.path.exists(video_path):
                print(f"Video not found: {video_path}")
                continue

            # Keep only the id prefix for chunk naming.
            video_name = video_name[:-4].split('_')[0]
            start_frame = 0
            end_frame = num_frames

            # NOTE(review): unused below; also divides by zero if cam_data_len == 1.
            cam_interval = end_frame // (cam_data_len - 1)

            # Frame indices that camera poses correspond to (currently unused).
            cam_frames = np.linspace(start_frame, end_frame, cam_data_len, endpoint=True)
            cam_frames = np.round(cam_frames).astype(int)
            cam_frames = cam_frames.tolist()

            # Chunk start offsets: 0, 300, 600, ...
            sampled_range = range(start_frame, end_frame, chunk_size)
            sampled_frames = list(sampled_range)

            print(f"Encoding scene {video_name}...")
            chunk_count_in_one_video = 0

            for sampled_chunk_start in sampled_frames:
                # Skip trailing chunks shorter than 100 frames.
                if num_frames - sampled_chunk_start < 100:
                    continue

                sampled_chunk_end = sampled_chunk_start + chunk_size
                start_str = f"{sampled_chunk_start:07d}"
                end_str = f"{sampled_chunk_end:07d}"

                chunk_name = f"{video_name}_{start_str}_{end_str}"
                save_chunk_dir = os.path.join(output_dir, chunk_name)
                os.makedirs(save_chunk_dir, exist_ok=True)

                print(f"Encoding chunk {chunk_name}...")

                first_latent_path = os.path.join(save_chunk_dir, "first_latent.pth")

                # Resumability: skip chunks that were already encoded.
                if os.path.exists(first_latent_path):
                    print(f"First latent for chunk {chunk_name} already exists, skipping...")
                    continue

                # Load only the single frame that is needed.
                first_frame_idx = sampled_chunk_start
                print(f"first_frame:{first_frame_idx}")
                first_frame = encoder.load_single_frame(video_path, first_frame_idx)

                if first_frame is None:
                    print(f"Failed to load frame {first_frame_idx} from: {video_path}")
                    continue

                first_frame = first_frame.to("cuda", dtype=torch.bfloat16)

                # Repeat the frame 4x along the time axis (VAE temporal window).
                repeated_first_frame = first_frame.repeat(1, 1, 4, 1, 1)
                print(f"Repeated first frame shape: {repeated_first_frame.shape}")

                with torch.no_grad():
                    first_latents = encoder.pipe.encode_video(repeated_first_frame, **encoder.tiler_kwargs)[0]

                first_latent_data = {
                    "latents": first_latents.cpu(),
                }
                torch.save(first_latent_data, first_latent_path)
                print(f"Saved first latent: {first_latent_path}")

                processed_chunk_count += 1
                chunk_count_in_one_video += 1

            processed_count += 1
            print("Encoded scene number:", processed_count)
            print("Encoded chunk number:", processed_chunk_count)

    print(f"Encoding completed! Processed {processed_count} scenes.")
272
+
273
if __name__ == "__main__":
    # CLI entry point: collect paths, then encode all scenes.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--scenes_path", type=str,
        default="/share_zhuyixuan05/public_datasets/SpatialVID-HQ/SpatialVid/HQ/videos/")
    arg_parser.add_argument(
        "--text_encoder_path", type=str,
        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
    arg_parser.add_argument(
        "--vae_path", type=str,
        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
    arg_parser.add_argument(
        "--output_dir", type=str,
        default="/share_zhuyixuan05/zhuyixuan05/spatialvid")
    opts = arg_parser.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path,
                  opts.vae_path, opts.output_dir)
scripts/hud_logo.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from PIL import Image, ImageDraw, ImageFont
import os

os.makedirs("wasd_ui", exist_ok=True)

# UI sizes (small)
key_size = (48, 48)
corner = 10
bg_padding = 6

# FIX: fall back to PIL's built-in bitmap font when "arial.ttf" is not
# installed (common on Linux/CI) instead of crashing at import time.
try:
    font = ImageFont.truetype("arial.ttf", 28)
except OSError:
    font = ImageFont.load_default()

def rounded_rect(im, bbox, radius, fill):
    """Draw a filled rounded rectangle (RGBA fill supported) onto *im*."""
    draw = ImageDraw.Draw(im, "RGBA")
    draw.rounded_rectangle(bbox, radius=radius, fill=fill)

# Background plate sized for a 3x2 key grid plus padding.
bg_width = key_size[0] * 3 + bg_padding * 4
bg_height = key_size[1] * 2 + bg_padding * 4
ui_bg = Image.new("RGBA", (bg_width, bg_height), (0,0,0,0))
rounded_rect(ui_bg, (0,0,bg_width,bg_height), corner, (0,0,0,140))
ui_bg.save("wasd_ui/ui_background.png")

keys = ["W","A","S","D"]

def draw_key(char, active):
    """Render one key cap; active keys are brighter with darker glyph color."""
    im = Image.new("RGBA", key_size, (0,0,0,0))
    rounded_rect(im, (0,0,key_size[0],key_size[1]), corner,
                 (255,255,255,230) if active else (200,200,200,180))
    draw = ImageDraw.Draw(im)
    color = (0,0,0) if active else (50,50,50)
    # FIX: ImageDraw.textsize() was deprecated in Pillow 9.2 and removed in
    # Pillow 10 — measure the glyph with textbbox() instead.
    left, top, right, bottom = draw.textbbox((0, 0), char, font=font)
    w, h = right - left, bottom - top
    # Offset by (-left, -top) so the glyph's ink box (not its layout box,
    # which includes bearing/ascender space) is centered on the key.
    draw.text(((key_size[0]-w)//2 - left, (key_size[1]-h)//2 - top),
              char, font=font, fill=color)
    return im

for k in keys:
    draw_key(k, False).save(f"wasd_ui/key_{k}_idle.png")
    draw_key(k, True).save(f"wasd_ui/key_{k}_active.png")

print("✅ WASD UI assets generated in ./wasd_ui/")
scripts/infer_demo.py ADDED
@@ -0,0 +1,1458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
5
+ sys.path.append(ROOT_DIR)
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+ from PIL import Image
11
+ import imageio
12
+ import json
13
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
14
+ import argparse
15
+ from torchvision.transforms import v2
16
+ from einops import rearrange
17
+ import random
18
+ import copy
19
+ from datetime import datetime
20
+
21
def compute_relative_pose_matrix(pose1, pose2):
    """Compute the relative pose between two adjacent frames as a 3x4 camera matrix [R_rel | t_rel].

    Args:
        pose1: pose of frame i, shape (7,): [tx1, ty1, tz1, qx1, qy1, qz1, qw1]
        pose2: pose of frame i+1, shape (7,): [tx2, ty2, tz2, qx2, qy2, qz2, qw2]

    Returns:
        relative_matrix: (3, 4) array whose first 3 columns are the relative
        rotation R_rel and whose last column is the relative translation t_rel
        expressed in frame i's coordinates.
    """
    # FIX: `R` (scipy Rotation) is not imported at this script's module level,
    # so the original raised NameError at call time. Import it locally; this
    # is also harmless if a module-level alias exists.
    from scipy.spatial.transform import Rotation as R

    # Split translation and quaternion parts.
    t1, q1 = pose1[:3], pose1[3:]
    t2, q2 = pose2[:3], pose2[3:]

    # 1. Relative rotation: R_rel = R2 * R1^-1 (later frame times inverse of earlier).
    rot1 = R.from_quat(q1)
    rot2 = R.from_quat(q2)
    R_rel = (rot2 * rot1.inv()).as_matrix()

    # 2. Relative translation: t_rel = R1^T (t2 - t1) (R1^T == R1^-1 for rotations).
    t_rel = rot1.as_matrix().T @ (t2 - t1)

    # 3. Assemble the 3x4 matrix [R_rel | t_rel].
    relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])

    return relative_matrix
52
+
53
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice out a window.

    Args:
        pth_path: path to a checkpoint containing a 'latents' tensor [C, T, H, W].
        start_frame: first latent frame of the window.
        num_frames: number of latent frames to extract.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice and
        the full dict loaded from disk.

    Raises:
        ValueError: if the requested window extends past the stored sequence.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    end_frame = start_frame + num_frames
    available = full_latents.shape[1]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {end_frame}")

    if end_frame > available:
        raise ValueError(f"Not enough frames: requested {end_frame}, available {available}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
70
+
71
+
72
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return camera B's pose relative to camera A, i.e. pose_b @ inverse(pose_a).

    Both inputs are 4x4 homogeneous extrinsic matrices (numpy arrays or, when
    use_torch=True, optionally torch tensors). The result matches the input
    backend: a torch tensor when use_torch=True, otherwise a numpy array.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float tensors before inverting.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # Numpy path: coerce loose inputs (lists etc.) to float32 arrays.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
95
+
96
+
97
def replace_dit_model_in_manager():
    """Patch diffsynth's loader configs so 'wan_video_dit' resolves to the MoE model class.

    Mutates `model_loader_configs` in place: every entry that lists
    'wan_video_dit' gets its class swapped for WanModelMoe; all other
    names/classes are left untouched.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        if 'wan_video_dit' not in names:
            continue

        patched_names = []
        patched_classes = []
        for model_name, model_cls in zip(names, classes):
            patched_names.append(model_name)
            if model_name == 'wan_video_dit':
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {model_name} -> WanModelMoe")
            else:
                patched_classes.append(model_cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
120
+
121
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to *dit_model* (no-op if present)."""
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Transformer width, inferred from the first block's attention projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Projects 16-channel latents to the transformer width at 1x/2x/4x downsampling."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            layers = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in layers:
                raise ValueError(f"Unsupported scale: {scale}")
            conv = layers[scale]
            # Cast the input to the convolution's dtype before projecting.
            return conv(x.to(conv.weight.dtype))

    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
    # Align the new embedder's dtype with the rest of the model's parameters.
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
150
+
151
+
152
def add_moe_components(dit_model, moe_config):
    """🔧 Add MoE components - corrected version.

    Installs per-modality input processors, a global expert router, and one
    MultiModalMoE network per transformer block on an already-loaded model.
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")
    dit_model.top_k = moe_config.get("top_k", 1)

    # Dynamically attach MoE components to each block.
    # Transformer width, inferred from the first block's attention projection.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)
    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
    # One processor per dataset modality, mapping raw camera/action inputs of
    # varying width into the shared `unified_dim` space.
    dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
    dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
    dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim)  # OpenX uses 13-dim input, like sekai but processed independently
    dit_model.global_router = nn.Linear(unified_dim, num_experts)

    for i, block in enumerate(dit_model.blocks):
        # MoE network: consumes unified_dim, emits dim.
        # NOTE(review): model-level top_k defaults to 1 above, but each block's
        # MoE defaults to top_k=2 here — confirm this asymmetry is intended.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,  # output width matches the transformer block's dim
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
180
+
181
+
182
def _sekai_synthetic_pose(yaw_per_frame, forward_speed, lateral_shift=0.0):
    """Build one synthetic 4x4 relative pose: yaw about the Y axis plus translation.

    Negative Z translation encodes forward motion; `lateral_shift` goes into
    the X translation (used only by the s_curve profile).
    """
    cos_yaw = np.cos(yaw_per_frame)
    sin_yaw = np.sin(yaw_per_frame)
    pose = np.eye(4, dtype=np.float32)
    pose[0, 0] = cos_yaw
    pose[0, 2] = sin_yaw
    pose[2, 0] = -sin_yaw
    pose[2, 2] = cos_yaw
    pose[2, 3] = -forward_speed
    pose[0, 3] = lateral_shift
    return pose


def generate_sekai_camera_embeddings_sliding(
        cam_data,
        start_frame,
        initial_condition_frames,
        new_frames,
        total_generated,
        use_real_poses=True,
        direction="left"):
    """Generate camera embeddings for the Sekai dataset - sliding window version.

    Refactored: the six synthetic motion branches were ~250 lines of duplicated
    code differing only in (yaw, forward, lateral-shift) constants; they are now
    a profile table driving a shared pose builder. Behavior is unchanged,
    including the synthetic branch's condition mask extending one frame past
    `initial_condition_frames`.

    Args:
        cam_data: dict whose 'extrinsic' key holds an (N, 4, 4) array of camera
            extrinsics, or None to synthesize motion.
        start_frame: start index of the current generation window.
        initial_condition_frames: number of initial condition frames.
        new_frames: number of new frames generated this step.
        total_generated: total frames generated so far (kept for interface
            compatibility; unused here).
        use_real_poses: whether to use real Sekai camera poses.
        direction: synthetic camera motion profile, default "left".

    Returns:
        (M, 13) bfloat16 tensor: the flattened 3x4 relative pose per frame plus
        a trailing condition-mask column.

    Raises:
        ValueError: for an unknown synthetic `direction`.
    """
    time_compression_ratio = 4

    # Frames FramePack actually consumes:
    # 1 initial + 16 at 4x + 2 at 2x + 1 at 1x + new_frames
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Ensure the camera sequence is long enough for every consumer.
        max_needed_frames = max(
            start_frame + initial_condition_frames + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + initial_condition_frames + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Position of this latent frame in the original (uncompressed) sequence.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                relative_pose = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        # (M, 3, 4) -> (M, 12); equivalent to einops 'b c d -> b (c d)'.
        pose_embedding = torch.stack(relative_poses, dim=0).flatten(start_dim=1)

        # Condition mask over [start_frame, start_frame + initial_condition_frames).
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    # ---------- synthetic motion ----------
    max_needed_frames = max(
        start_frame + initial_condition_frames + new_frames,
        framepack_needed_frames,
        30)

    print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")

    CONDITION_FRAMES = initial_condition_frames
    STAGE_1 = new_frames // 2
    STAGE_2 = new_frames - STAGE_1

    # Per-direction (yaw, forward_speed) for stage 1 and stage 2.
    # Positive yaw turns left; forward_speed > 0 moves the camera forward.
    profiles = {
        "left": ((0.03, 0.00), (0.03, 0.00)),
        "right": ((-0.03, 0.00), (-0.03, 0.00)),
        "forward_left": ((0.03, 0.03), (0.03, 0.03)),
        "forward_right": ((-0.03, 0.03), (-0.03, 0.03)),
        "s_curve": ((0.03, 0.03), (-0.03, 0.03)),
        "left_right": ((0.03, 0.00), (-0.03, 0.00)),
    }
    banners = {
        "left": "--------------- LEFT TURNING MODE ---------------",
        "right": "--------------- RIGHT TURNING MODE ---------------",
        "forward_left": "--------------- FORWARD LEFT MODE ---------------",
        "forward_right": "--------------- FORWARD RIGHT MODE ---------------",
        "s_curve": "--------------- S CURVE MODE ---------------",
        "left_right": "--------------- LEFT RIGHT MODE ---------------",
    }
    if direction not in profiles:
        raise ValueError(f"未定义的相机运动方向: {direction}")

    print(banners[direction])
    stage1_motion, stage2_motion = profiles[direction]
    motion_end = CONDITION_FRAMES + STAGE_1 + STAGE_2

    relative_poses = []
    for i in range(max_needed_frames):
        if i < CONDITION_FRAMES or i >= motion_end:
            # Condition frames and frames past the target window: zero motion.
            pose = np.eye(4, dtype=np.float32)
        elif i < CONDITION_FRAMES + STAGE_1:
            yaw, forward = stage1_motion
            pose = _sekai_synthetic_pose(yaw, forward)
        else:
            yaw, forward = stage2_motion
            # s_curve keeps a slight leftward drift (inertia) early in stage 2.
            if direction == "s_curve" and i < CONDITION_FRAMES + STAGE_1 + STAGE_2 // 3:
                lateral_shift = -0.01
            else:
                lateral_shift = 0.00
            pose = _sekai_synthetic_pose(yaw, forward, lateral_shift)

        relative_poses.append(torch.as_tensor(pose[:3, :]))

    # (M, 3, 4) -> (M, 12); equivalent to einops 'b c d -> b (c d)'.
    pose_embedding = torch.stack(relative_poses, dim=0).flatten(start_dim=1)

    # NOTE: unlike the real-pose branch, the original synthetic branch marks
    # initial_condition_frames + 1 frames as condition — preserved here.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + initial_condition_frames + 1, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
523
+
524
+
525
def generate_openx_camera_embeddings_sliding(
    encoded_data, start_frame, initial_condition_frames, new_frames, use_real_poses):
    """Generate camera embeddings for the OpenX dataset (sliding-window variant).

    Returns a ``[max_needed_frames, 13]`` bfloat16 tensor: a flattened 3x4
    relative camera pose (12 values) per latent frame, plus one mask value
    (1.0 = condition frame, 0.0 = target frame).
    """
    time_compression_ratio = 4

    # Camera frames the FramePack index layout consumes:
    # 1 start + 16 (4x) + 2 (2x) + 1 (1x) + the newly generated frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    have_real_poses = (
        use_real_poses
        and encoded_data is not None
        and 'cam_emb' in encoded_data
        and 'extrinsic' in encoded_data['cam_emb']
    )

    if have_real_poses:
        print("🔧 使用OpenX真实camera数据")
        extrinsics = encoded_data['cam_emb']['extrinsic']

        # Make the camera sequence long enough for every consumer.
        max_needed_frames = max(
            start_frame + initial_condition_frames + new_frames,
            framepack_needed_frames,
            30,
        )

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + initial_condition_frames + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for step in range(max_needed_frames):
            # OpenX samples extrinsics at the 4x temporal compression stride.
            cur_idx = step * time_compression_ratio
            nxt_idx = cur_idx + time_compression_ratio

            if nxt_idx < len(extrinsics):
                rel = compute_relative_pose(extrinsics[cur_idx], extrinsics[nxt_idx])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the available data: fall back to zero motion.
                print(f"⚠️ 帧{cur_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        # Flatten each 3x4 pose to a 12-vector.
        pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

        # Condition mask over the same number of frames.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + initial_condition_frames, max_needed_frames)] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用OpenX合成camera数据")

    max_needed_frames = max(
        start_frame + initial_condition_frames + new_frames,
        framepack_needed_frames,
        30,
    )
    print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

    # Synthetic robot-manipulation motion: small fixed per-frame rotations
    # and a slow translation. The pose is constant across frames, so it is
    # built once and repeated.
    roll_per_frame = 0.02    # slight roll
    pitch_per_frame = 0.01   # slight pitch
    yaw_per_frame = 0.015    # slight yaw
    forward_speed = 0.003    # slow forward motion

    cos_roll, sin_roll = np.cos(roll_per_frame), np.sin(roll_per_frame)
    cos_pitch, sin_pitch = np.cos(pitch_per_frame), np.sin(pitch_per_frame)
    cos_yaw, sin_yaw = np.cos(yaw_per_frame), np.sin(yaw_per_frame)

    pose = np.eye(4, dtype=np.float32)
    # Composite rotation matrix in ZYX order.
    pose[0, 0] = cos_yaw * cos_pitch
    pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
    pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
    pose[1, 0] = sin_yaw * cos_pitch
    pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
    pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
    pose[2, 0] = -sin_pitch
    pose[2, 1] = cos_pitch * sin_roll
    pose[2, 2] = cos_pitch * cos_roll
    # Fine translation mimicking manipulation moves; -Z (depth) dominates.
    pose[0, 3] = forward_speed * 0.5
    pose[1, 3] = forward_speed * 0.3
    pose[2, 3] = -forward_speed

    per_frame_pose = torch.as_tensor(pose[:3, :])
    pose_embedding = torch.stack([per_frame_pose] * max_needed_frames, dim=0)
    pose_embedding = pose_embedding.reshape(max_needed_frames, -1)

    # Condition mask over the same number of frames.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    mask[start_frame:min(start_frame + initial_condition_frames, max_needed_frames)] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
640
+
641
+
642
def generate_nuscenes_camera_embeddings_sliding(
    scene_info, start_frame, initial_condition_frames, new_frames):
    """Build camera embeddings for NuScenes (sliding-window variant).

    Output: ``[max_needed_frames, 8]`` bfloat16 — a 7D pose vector
    (tx, ty, tz, qw, qx, qy, qz) per frame plus one condition-mask value,
    matching the layout used by train_moe.py.
    """
    time_compression_ratio = 4

    # Camera frames the FramePack index layout consumes.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    def _finish(pose_sequence, n_frames):
        # Append the condition-mask column; shared by every branch below.
        mask = torch.zeros(n_frames, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + initial_condition_frames, n_frames)] = 1.0
        return torch.cat([pose_sequence, mask], dim=1)  # [n_frames, 8]

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if not keyframe_poses:
            # No usable keyframes: emit an all-zero pose sequence.
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            n_frames = max(framepack_needed_frames, 30)
            embedding = _finish(torch.zeros(n_frames, 7, dtype=torch.float32), n_frames)
            print(f"🔧 NuScenes零pose embedding shape: {embedding.shape}")
            return embedding.to(torch.bfloat16)

        # Translations are expressed relative to the first keyframe.
        ref_translation = np.array(keyframe_poses[0]['translation'])
        n_frames = max(framepack_needed_frames, 30)

        pose_vecs = []
        for idx in range(n_frames):
            if idx < len(keyframe_poses):
                current = keyframe_poses[idx]
                # Relative displacement; rotation kept as-is (simplified).
                translation = torch.tensor(
                    np.array(current['translation']) - ref_translation,
                    dtype=torch.float32,
                )
                rotation = torch.tensor(current['rotation'], dtype=torch.float32)
            else:
                # Past the available poses: identity pose.
                translation = torch.zeros(3, dtype=torch.float32)
                rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
            pose_vecs.append(torch.cat([translation, rotation], dim=0))  # [7]

        embedding = _finish(torch.stack(pose_vecs, dim=0), n_frames)
        print(f"🔧 NuScenes真实pose embedding shape: {embedding.shape}")
        return embedding.to(torch.bfloat16)

    print("🔧 使用NuScenes合成pose数据")
    n_frames = max(framepack_needed_frames, 30)

    # Synthetic left-turn trajectory, like a city-driving turn.
    radius = 15.0  # large turning radius, appropriate for a car
    pose_vecs = []
    for idx in range(n_frames):
        angle = idx * 0.04  # per-frame rotation increment (radians)

        # Position on the circular-arc trajectory (planar motion, y fixed).
        translation = torch.tensor(
            [radius * np.sin(angle), 0.0, radius * (1 - np.cos(angle))],
            dtype=torch.float32,
        )

        # Heading follows the arc tangent; quaternion for rotation about Y.
        yaw = angle + np.pi / 2
        rotation = torch.tensor(
            [np.cos(yaw / 2), 0.0, 0.0, np.sin(yaw / 2)],
            dtype=torch.float32,
        )

        pose_vecs.append(torch.cat([translation, rotation], dim=0))  # [7]

    embedding = _finish(torch.stack(pose_vecs, dim=0), n_frames)
    print(f"🔧 NuScenes合成左转pose embedding shape: {embedding.shape}")
    return embedding.to(torch.bfloat16)
749
+
750
def prepare_framepack_sliding_window_with_camera_moe(
    history_latents,
    target_frames_to_generate,
    camera_embedding_full,
    start_frame,
    modality_type,
    max_history_frames=49):
    """FramePack sliding-window preparation - MoE version.

    Arranges the history latents and the full camera-embedding sequence into
    the fixed FramePack index layout and rebuilds the condition mask for the
    current history length. ``start_frame`` and ``max_history_frames`` are
    accepted for interface compatibility but are not used here.
    """
    # history_latents: [C, T, H, W] latents accumulated so far.
    C, T, H, W = history_latents.shape

    # Fixed index layout (this determines the camera frames needed):
    # [start(1) | 4x(16) | 2x(2) | 1x(1) | target(N)]; 19 = 16 + 2 + 1.
    total_len = 1 + 16 + 2 + 1 + target_frames_to_generate
    idx_start, idx_4x, idx_2x, idx_1x, latent_indices = torch.arange(0, total_len).split(
        [1, 16, 2, 1, target_frames_to_generate], dim=0)
    clean_latent_indices = torch.cat([idx_start, idx_1x], dim=0)

    # Zero-pad the camera sequence when it is shorter than the layout.
    if camera_embedding_full.shape[0] < total_len:
        print(f"⚠️ camera_embedding长度不足,进行零补齐: 当前长度 {camera_embedding_full.shape[0]}, 需要长度 {total_len}")
        pad = torch.zeros(
            total_len - camera_embedding_full.shape[0],
            camera_embedding_full.shape[1],
            dtype=camera_embedding_full.dtype,
            device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    combined_camera = torch.zeros(
        total_len,
        camera_embedding_full.shape[1],
        dtype=camera_embedding_full.dtype,
        device=camera_embedding_full.device)

    # Right-align up to 19 frames of history camera poses in the context slots.
    hist_cam = camera_embedding_full[max(T - 19, 0):T, :].clone()
    combined_camera[19 - hist_cam.shape[0]:19, :] = hist_cam

    # Target-frame camera poses follow directly after the context slots.
    tgt_cam = camera_embedding_full[T:T + target_frames_to_generate, :].clone()
    combined_camera[19:19 + tgt_cam.shape[0], :] = tgt_cam

    # Rebuild the condition mask (last channel): first everything is target…
    combined_camera[:, -1] = 0.0
    available_frames = min(T, 19) if T > 0 else 0
    if available_frames:
        # …then flag slots backed by real clean latents as condition frames.
        combined_camera[19 - available_frames:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames}")
    print(f" - 模态类型: {modality_type}")

    # Assemble the 19-slot clean-latent context window, right-aligned.
    clean_latents_combined = torch.zeros(
        C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if available_frames:
        clean_latents_combined[:, 19 - available_frames:, :, :] = \
            history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # Anchor on the very first history frame (zeros when there is no history).
    if T > 0:
        anchor_latent = history_latents[:, 0:1, :, :]
    else:
        anchor_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([anchor_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag passed through to the caller
        'current_length': T,
        'next_length': T + target_frames_to_generate,
    }
839
+
840
def overlay_controls(frame_img, pose_vec, icons):
    """
    Overlay WASD / arrow control icons on a frame based on the camera pose.

    pose_vec: 12 elements (a flattened 3x4 pose matrix) followed by a mask
    value. Returns the image (modified in place when icons are pasted).
    """
    if pose_vec is None or np.all(pose_vec[:12] == 0):
        return frame_img

    # Flattened 3x4 layout:
    # [r00, r01, r02, tx, r10, r11, r12, ty, r20, r21, r22, tz]
    tx = pose_vec[3]
    # ty = pose_vec[7]
    tz = pose_vec[11]

    # Yaw (about Y): sin(yaw) = r02, cos(yaw) = r00.
    yaw = np.arctan2(pose_vec[2], pose_vec[0])
    # Pitch (about X): sin(pitch) = -r12, cos(pitch) = r22.
    pitch = np.arctan2(-pose_vec[6], pose_vec[10])

    # Activation thresholds for key presses.
    TRANS_THRESH = 0.01
    ROT_THRESH = 0.005

    # Key states. Translation (WASD): -Z is forward, +X is right.
    # Rotation (arrows): yaw + = left, - = right; pitch + = down, - = up.
    key_state = {
        'forward': tz < -TRANS_THRESH,
        'backward': tz > TRANS_THRESH,
        'left': tx < -TRANS_THRESH,
        'right': tx > TRANS_THRESH,
        'turn_left': yaw > ROT_THRESH,
        'turn_right': yaw < -ROT_THRESH,
        'turn_up': pitch < -ROT_THRESH,
        'turn_down': pitch > ROT_THRESH,
    }

    img_w, img_h = frame_img.size
    spacing = 60

    def stamp(active_name, inactive_name, active, x, y):
        chosen = active_name if active else inactive_name
        icon = icons.get(chosen)
        if icon is not None:
            # Paste using the icon's own alpha channel.
            frame_img.paste(icon, (int(x), int(y)), icon)

    # WASD cluster (bottom-left).
    wasd_x = 100
    base_y = img_h - 100
    stamp('move_forward.png', 'not_move_forward.png', key_state['forward'], wasd_x, base_y - spacing)   # W
    stamp('move_left.png', 'not_move_left.png', key_state['left'], wasd_x - spacing, base_y)            # A
    stamp('move_backward.png', 'not_move_backward.png', key_state['backward'], wasd_x, base_y)          # S
    stamp('move_right.png', 'not_move_right.png', key_state['right'], wasd_x + spacing, base_y)         # D

    # Arrow cluster (bottom-right).
    arrow_x = img_w - 150
    stamp('turn_up.png', 'not_turn_up.png', key_state['turn_up'], arrow_x, base_y - spacing)            # up
    stamp('turn_left.png', 'not_turn_left.png', key_state['turn_left'], arrow_x - spacing, base_y)      # left
    stamp('turn_down.png', 'not_turn_down.png', key_state['turn_down'], arrow_x, base_y)                # down
    stamp('turn_right.png', 'not_turn_right.png', key_state['turn_right'], arrow_x + spacing, base_y)   # right

    return frame_img
922
+
923
+
924
def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="../examples/output_videos/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai", "nuscenes" or "openx"
    use_real_poses=True,
    scene_info_path=None,  # scene metadata JSON, NuScenes modality only
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None,
    direction="left",
    use_gt_prompt=True,
    add_icons=False
    ):
    """
    MoE FramePack sliding-window video generation with multi-modality support.

    Loads a Wan2.1 pipeline, augments the DiT with per-block camera encoders,
    FramePack components and MoE experts, loads the checkpoint at ``dit_path``,
    then autoregressively generates ``total_frames_to_generate`` latent frames
    (``frames_per_generation`` per step) conditioned on camera embeddings for
    the selected modality, optionally with camera and/or text CFG. The decoded
    video is written to ``output_path``; with ``add_icons`` control-key icons
    are overlaid per frame from the camera poses.
    """
    # Create the output directory.
    dir_path = os.path.dirname(output_path)
    os.makedirs(dir_path, exist_ok=True)

    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialisation.
    replace_dit_model_in_manager()

    # NOTE(review): model paths are hard-coded to a specific mount point.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Add the legacy per-block camera encoder (checkpoint compatibility);
    # initialised so it is initially an identity/no-op contribution.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Add FramePack components.
    add_framepack_components(pipe.dit)

    # 4. Add MoE components.
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,    # Sekai: 12-D pose + 1-D mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-D pose + 1-D mask
        "openx_input_dim": 13     # OpenX: 12-D pose + 1-D mask (sekai-like)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load the trained weights.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    # strict=False so newly added MoE components without weights are tolerated.
    pipe.dit.load_state_dict(dit_state_dict, strict=False)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    # Set the number of denoising steps.
    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Spatial center-crop to the model's latent resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the prompt, with CFG support.
    if use_gt_prompt and 'prompt_emb' in encoded_data:
        print("✅ 使用预编码的GT prompt embedding")
        prompt_emb_pos = encoded_data['prompt_emb']
        # Move the prompt embedding to the target device / dtype.
        if 'context' in prompt_emb_pos:
            prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype)
        if 'context_mask' in prompt_emb_pos:
            prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype)

        # If text CFG is enabled, also encode an empty negative prompt.
        if text_guidance_scale > 1.0:
            prompt_emb_neg = pipe.encode_prompt("")
            print(f"使用Text CFG with GT prompt,guidance scale: {text_guidance_scale}")
        else:
            prompt_emb_neg = None
            print("不使用Text CFG")

        # Print the GT prompt text, when present.
        if 'prompt' in encoded_data['prompt_emb']:
            gt_prompt_text = encoded_data['prompt_emb']['prompt']
            print(f"📝 GT Prompt文本: {gt_prompt_text}")
    else:
        # Re-encode using the prompt argument passed in.
        print(f"🔄 重新编码prompt: {prompt}")
        if text_guidance_scale > 1.0:
            prompt_emb_pos = pipe.encode_prompt(prompt)
            prompt_emb_neg = pipe.encode_prompt("")
            print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
        else:
            prompt_emb_pos = pipe.encode_prompt(prompt)
            prompt_emb_neg = None
            print("不使用Text CFG")

    # 8. Load scene metadata (NuScenes only).
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera-embedding sequence for the modality.
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            start_frame,
            initial_condition_frames,
            total_frames_to_generate,
            0,
            use_real_poses=use_real_poses,
            direction=direction
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            start_frame,
            initial_condition_frames,
            total_frames_to_generate
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            start_frame,
            initial_condition_frames,
            total_frames_to_generate,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. Zeroed (unconditional) camera embedding for camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation - MoE version.
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Prepare the batched inputs.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # Per-modality routing input for the MoE experts.
        modality_inputs = {modality_type: camera_embedding}

        # Unconditional camera embedding (truncated to match) for CFG.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors are kept on CPU.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Initialise the latents to be generated with Gaussian noise.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop, with optional CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # CFG inference.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera).
                    noise_pred_cond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zeroed camera).
                    noise_pred_uncond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # MoE unconditional modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # When text CFG is also enabled, apply it on top.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loess = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the camera-CFG-combined prediction.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only.
                    noise_pred_cond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loess= pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Standard inference (no CFG).
                    noise_pred, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the new frames to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep the first frame + most recent frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode and save.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path} ...")

    # [-1, 1] -> [0, 255] uint8, frames-first layout for the writer.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    icons = {}
    video_camera_poses = None
    if add_icons:
        # Load the icon assets used for the overlay.
        icons_dir = os.path.join(ROOT_DIR, 'icons')
        icon_names = ['move_forward.png', 'not_move_forward.png',
                      'move_backward.png', 'not_move_backward.png',
                      'move_left.png', 'not_move_left.png',
                      'move_right.png', 'not_move_right.png',
                      'turn_up.png', 'not_turn_up.png',
                      'turn_down.png', 'not_turn_down.png',
                      'turn_left.png', 'not_turn_left.png',
                      'turn_right.png', 'not_turn_right.png']
        for name in icon_names:
            path = os.path.join(icons_dir, name)
            if os.path.exists(path):
                try:
                    icon = Image.open(path).convert("RGBA")
                    # Resize the icon to overlay size.
                    icon = icon.resize((50, 50), Image.Resampling.LANCZOS)
                    icons[name] = icon
                except Exception as e:
                    print(f"Error loading icon {name}: {e}")
            else:
                print(f"Warning: Icon {name} not found at {path}")

        # Expand latent-frame poses to video frames (4x temporal compression).
        time_compression_ratio = 4
        camera_poses = camera_embedding_full.detach().float().cpu().numpy()
        video_camera_poses = [x for x in camera_poses for _ in range(time_compression_ratio)]

    with imageio.get_writer(output_path, fps=20) as writer:
        for i, frame in enumerate(video_np):
            # Convert to PIL for overlay
            img = Image.fromarray(frame)

            if add_icons and video_camera_poses is not None and icons:
                # Video frame i corresponds to camera_embedding_full[start_frame + i]
                pose_idx = start_frame + i
                if pose_idx < len(video_camera_poses):
                    pose_vec = video_camera_poses[pose_idx]
                    img = overlay_controls(img, pose_vec, icons)

            writer.append_data(np.array(img))

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
1367
+
1368
+
1369
def _parse_cli_bool(value):
    """Parse a boolean CLI value ('true'/'false', '1'/'0', 'yes'/'no', ...).

    With plain ``default=False`` and no ``type``, argparse stores the raw
    string, so ANY supplied value — including the literal string "False" —
    is truthy. Routing values through this parser keeps ``--flag False``
    working as users expect while preserving ``--flag True`` usage.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("1", "true", "t", "yes", "y")


def main():
    """CLI entry point: parse arguments and run MoE FramePack sliding-window generation."""
    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    # Basic parameters
    parser.add_argument("--condition_pth", type=str,
                        default="../examples/condition_pth/garden_1.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=1)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=24)
    parser.add_argument("--max_history_frames", type=int, default=100)
    # BUGFIX: was `default=False` with no type, so any value (even "False")
    # parsed as truthy; now parsed as a real boolean.
    parser.add_argument("--use_real_poses", type=_parse_cli_bool, default=False)
    parser.add_argument("--dit_path", type=str, default=None, required=True,
                        help="path to the pretrained DiT MoE model checkpoint")
    parser.add_argument("--output_path", type=str,
                        default='./examples/output_videos/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str, default=None,
                        help="text prompt for video generation")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--add_icons", action="store_true", default=False,
                        help="在生成的视频上叠加控制图标")

    # Modality parameters
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"],
                        default="sekai", help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # CFG parameters
    # BUGFIX: same truthy-string pitfall as --use_real_poses; now a real boolean.
    parser.add_argument("--use_camera_cfg", type=_parse_cli_bool, default=False,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # MoE parameters
    parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
    parser.add_argument("--direction", type=str, default="left", help="生成视频的行进轨迹方向")
    parser.add_argument("--use_gt_prompt", action="store_true", default=False,
                        help="使用数据集中的ground truth prompt embedding")

    args = parser.parse_args()

    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"使用GT Prompt: {args.use_gt_prompt}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
    # BUGFIX: log message was "DiT{path}" with no separator.
    print(f"DiT: {args.dit_path}")

    # Validate NuScenes parameters.
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim,
        direction=args.direction,
        use_gt_prompt=args.use_gt_prompt,
        add_icons=args.add_icons
    )


if __name__ == "__main__":
    main()
scripts/infer_moe.py ADDED
@@ -0,0 +1,1023 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import imageio
7
+ import json
8
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
9
+ import argparse
10
+ from torchvision.transforms import v2
11
+ from einops import rearrange
12
+ import copy
13
+ from scipy.spatial.transform import Rotation as R
14
+
15
+
16
def compute_relative_pose_matrix(pose1, pose2):
    """Relative pose between two consecutive frames as a 3x4 [R_rel | t_rel] matrix.

    Args:
        pose1: camera pose of frame i, shape (7,): [tx, ty, tz, qx, qy, qz, qw]
        pose2: camera pose of frame i+1, same layout.

    Returns:
        np.ndarray of shape (3, 4): relative rotation in the first 3 columns,
        relative translation in the last column.
    """
    # Split each pose into translation and quaternion parts.
    trans_prev, quat_prev = pose1[:3], pose1[3:]
    trans_next, quat_next = pose2[:3], pose2[3:]

    # 1. Relative rotation: R_rel = R_next * R_prev^{-1}
    rot_prev = R.from_quat(quat_prev)
    rot_next = R.from_quat(quat_next)
    rel_rot_matrix = (rot_next * rot_prev.inv()).as_matrix()

    # 2. Relative translation expressed in the previous frame:
    #    t_rel = R_prev^T @ (t_next - t_prev)
    rel_translation = rot_prev.as_matrix().T @ (trans_next - trans_prev)

    # 3. Assemble the 3x4 matrix [R_rel | t_rel].
    return np.hstack([rel_rot_matrix, rel_translation.reshape(3, 1)])
47
+
48
+
49
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Relative rotation quaternion - NuScenes-specific helper.

    Computes q_ref^{-1} * q_current via an explicit Hamilton product.
    Both inputs are [w, x, y, z] quaternions; unit norm is assumed, so the
    conjugate of the reference quaternion acts as its inverse.
    """
    q_cur = torch.tensor(current_rotation, dtype=torch.float32)
    q_ref = torch.tensor(reference_rotation, dtype=torch.float32)
    # Conjugate of the reference quaternion (its inverse for unit quaternions).
    w1, x1, y1, z1 = q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]
    w2, x2, y2, z2 = q_cur
    return torch.tensor([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])
63
+
64
+
65
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice a frame window.

    Returns (condition_latents, encoded_data) where condition_latents is the
    [C, num_frames, H, W] slice starting at start_frame, and encoded_data is
    the full dict stored in the file.

    Raises:
        ValueError: if the requested window extends past the stored sequence.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
82
+
83
+
84
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Pose of camera B relative to camera A: pose_b @ pose_a^{-1}.

    Both inputs are 4x4 extrinsic matrices. With use_torch=True the result is
    a float torch tensor; otherwise a numpy array.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Torch path: promote numpy inputs to float tensors first.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # Numpy path: coerce plain sequences to float32 arrays.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
107
+
108
+
109
def replace_dit_model_in_manager():
    """Patch diffsynth's loader registry so 'wan_video_dit' maps to the MoE class.

    Mutates the module-level ``model_loader_configs`` list in place, rebuilding
    the (names, classes) pair of every entry that registers 'wan_video_dit'.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        if 'wan_video_dit' not in model_names:
            continue

        patched_names, patched_classes = [], []
        for name, cls in zip(model_names, model_classes):
            patched_names.append(name)
            if name == 'wan_video_dit':
                # Swap in the MoE variant while keeping the registered name.
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, model_resource)
131
+
132
+
133
def add_framepack_components(dit_model):
    """Attach FramePack's clean-latent embedder to a DiT model (idempotent).

    Creates ``dit_model.clean_x_embedder`` with three Conv3d projections for
    1x / 2x / 4x temporal-spatial compression, cast to the model's dtype.
    Does nothing if the attribute already exists.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Inner dim inferred from the first block's self-attention query projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Projects 16-channel clean latents at three compression scales."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Select the projection for the requested scale, casting the input
            # to that layer's weight dtype first.
            if scale == "1x":
                return self.proj(x.to(self.proj.weight.dtype))
            if scale == "2x":
                return self.proj_2x(x.to(self.proj_2x.weight.dtype))
            if scale == "4x":
                return self.proj_4x(x.to(self.proj_4x.weight.dtype))
            raise ValueError(f"Unsupported scale: {scale}")

    embedder = CleanXEmbedder(inner_dim)
    # Match the host model's parameter dtype.
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
162
+
163
+
164
def add_moe_components(dit_model, moe_config):
    """Attach MoE components (modality processors, router, per-block experts).

    The config dict is recorded once; per-modality ``ModalityProcessor``
    adapters, a global router, and one ``MultiModalMoE`` per transformer block
    are (re)assigned on every call.
    """
    # Record the config only once; repeated calls keep the first config.
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")
        dit_model.top_k = moe_config.get("top_k", 1)

    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

    # Block dim inferred from the first block's self-attention query projection.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)

    # Per-modality input adapters plus a shared global router.
    dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
    dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
    # OpenX uses a 13-dim input like sekai but is processed independently.
    dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim)
    dit_model.global_router = nn.Linear(unified_dim, num_experts)

    for i, block in enumerate(dit_model.blocks):
        # MoE network: consumes unified_dim features, emits block-dim features.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
192
+
193
+
194
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build per-frame camera embeddings for the Sekai dataset (sliding window).

    Returns a [seq_len, 13] bfloat16 tensor: 12 flattened relative-pose values
    per frame plus a trailing condition-mask column (1 = condition, 0 = target).
    """
    time_compression_ratio = 4

    # Camera frames FramePack consumes: start(1) + 4x(16) + 2x(2) + 1x(1) + new.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Sequence must cover the raw request, FramePack's layout, and a floor of 30.
    seq_len = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {seq_len}")

        rel_poses = []
        for idx in range(seq_len):
            # Map embedding index back to the raw (uncompressed) frame index.
            src = idx * time_compression_ratio
            nxt = src + time_compression_ratio
            if nxt < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[src], cam_extrinsic[nxt])
                rel_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{src}超出camera数据范围,使用零运动")
                rel_poses.append(torch.zeros(3, 4))

        # Flatten each 3x4 pose to 12 values per frame.
        pose_embedding = torch.stack(rel_poses, dim=0).reshape(seq_len, -1)
        mask = torch.zeros(seq_len, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + current_history_length, seq_len)] = 1.0
        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用Sekai合成camera数据")
    print(f"🔧 生成Sekai合成camera帧数: {seq_len}")

    # Synthetic motion: a constant left turn with slow forward drift.
    # The per-frame step is identical for every frame, so precompute it once.
    yaw_per_frame = -0.1
    forward_speed = 0.005
    cos_yaw, sin_yaw = np.cos(yaw_per_frame), np.sin(yaw_per_frame)

    step = np.eye(4, dtype=np.float32)
    # Rotation about the Y axis.
    step[0, 0] = cos_yaw
    step[0, 2] = sin_yaw
    step[2, 0] = -sin_yaw
    step[2, 2] = cos_yaw
    # Advance along the local -Z axis (forward).
    step[2, 3] = -forward_speed
    # Slight centripetal drift to mimic a circular path.
    step[0, 3] = 0.002

    rel_poses = [torch.as_tensor(step[:3, :]) for _ in range(seq_len)]

    # Flatten each 3x4 pose to 12 values per frame.
    pose_embedding = torch.stack(rel_poses, dim=0).reshape(seq_len, -1)
    mask = torch.zeros(seq_len, 1, dtype=torch.float32)
    mask[start_frame:min(start_frame + current_history_length, seq_len)] = 1.0
    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
294
+
295
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Build per-frame camera embeddings for the OpenX dataset (sliding window).

    Returns a [seq_len, 13] bfloat16 tensor: 12 flattened relative-pose values
    per frame plus a trailing condition-mask column (1 = condition, 0 = target).
    """
    time_compression_ratio = 4

    # Camera frames FramePack consumes: start(1) + 4x(16) + 2x(2) + 1x(1) + new.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Sequence must cover the raw request, FramePack's layout, and a floor of 30.
    seq_len = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    has_real = (use_real_poses and encoded_data is not None
                and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb'])

    if has_real:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {seq_len}")

        rel_poses = []
        for idx in range(seq_len):
            # OpenX samples at 4-frame strides, like sekai but on shorter clips.
            src = idx * time_compression_ratio
            nxt = src + time_compression_ratio
            if nxt < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[src], cam_extrinsic[nxt])
                rel_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{src}超出OpenX camera数据范围,使用零运动")
                rel_poses.append(torch.zeros(3, 4))

        # Flatten each 3x4 pose to 12 values per frame.
        pose_embedding = torch.stack(rel_poses, dim=0).reshape(seq_len, -1)
        mask = torch.zeros(seq_len, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + current_history_length, seq_len)] = 1.0
        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用OpenX合成camera数据")
    print(f"🔧 生成OpenX合成camera帧数: {seq_len}")

    # Synthetic robot-manipulation motion: small per-frame roll/pitch/yaw plus
    # a slow forward translation. Identical every frame, so precompute once.
    roll, pitch, yaw = 0.02, 0.01, 0.015
    forward_speed = 0.003
    cr, sr = np.cos(roll), np.sin(roll)
    cp, sp = np.cos(pitch), np.sin(pitch)
    cy, sy = np.cos(yaw), np.sin(yaw)

    step = np.eye(4, dtype=np.float32)
    # Simplified composite rotation (ZYX order) mimicking arm motion.
    step[0, 0] = cy * cp
    step[0, 1] = cy * sp * sr - sy * cr
    step[0, 2] = cy * sp * cr + sy * sr
    step[1, 0] = sy * cp
    step[1, 1] = sy * sp * sr + cy * cr
    step[1, 2] = sy * sp * cr - cy * sr
    step[2, 0] = -sp
    step[2, 1] = cp * sr
    step[2, 2] = cp * cr
    # Fine manipulation translation: mostly along depth (Z).
    step[0, 3] = forward_speed * 0.5
    step[1, 3] = forward_speed * 0.3
    step[2, 3] = -forward_speed

    rel_poses = [torch.as_tensor(step[:3, :]) for _ in range(seq_len)]

    # Flatten each 3x4 pose to 12 values per frame.
    pose_embedding = torch.stack(rel_poses, dim=0).reshape(seq_len, -1)
    mask = torch.zeros(seq_len, 1, dtype=torch.float32)
    mask[start_frame:min(start_frame + current_history_length, seq_len)] = 1.0
    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
409
+
410
+
411
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Build camera embeddings for NuScenes (sliding window), matching train_moe.py.

    Returns a [seq_len, 8] bfloat16 tensor per frame: 3D relative translation +
    4D relative rotation quaternion + trailing condition-mask column.
    """
    time_compression_ratio = 4
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    seq_len = max(framepack_needed_frames, 30)

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        # Raw keyframe indices (one extra sample so consecutive pairs can be
        # differenced below), clamped to the available pose list.
        kf_idx = [
            min((start_frame + i) * time_compression_ratio, len(keyframe_poses) - 1)
            for i in range(seq_len + 1)
        ]

        pose_vecs = []
        for i in range(seq_len):
            pose_prev = keyframe_poses[kf_idx[i]]
            pose_next = keyframe_poses[kf_idx[i + 1]]
            # Relative translation between consecutive (compressed) keyframes.
            translation = torch.tensor(
                np.array(pose_next['translation']) - np.array(pose_prev['translation']),
                dtype=torch.float32
            )
            # Relative rotation quaternion.
            relative_rotation = calculate_relative_rotation(
                pose_next['rotation'],
                pose_prev['rotation']
            )
            pose_vecs.append(torch.cat([translation, relative_rotation], dim=0))  # [7]

        pose_sequence = torch.stack(pose_vecs, dim=0)  # [seq_len, 7]
        mask = torch.zeros(seq_len, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + current_history_length, seq_len)] = 1.0
        camera_embedding = torch.cat([pose_sequence, mask], dim=1)
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用NuScenes合成pose数据")
    # Synthesize an absolute circular trajectory first (+1 sample so that
    # consecutive poses can be turned into relative motion below).
    abs_translations, abs_rotations = [], []
    for i in range(seq_len + 1):
        angle = -i * 0.12
        radius = 8.0
        abs_translations.append(np.array(
            [radius * np.sin(angle), 0.0, radius * (1 - np.cos(angle))],
            dtype=np.float32
        ))
        yaw = angle + np.pi / 2
        abs_rotations.append(np.array(
            [np.cos(yaw / 2), 0.0, 0.0, np.sin(yaw / 2)],
            dtype=np.float32
        ))

    # Per-frame motion relative to the previous frame.
    pose_vecs = []
    for i in range(seq_len):
        translation = torch.tensor(abs_translations[i + 1] - abs_translations[i], dtype=torch.float32)
        # Hamilton product q_prev^{-1} * q_next for the relative rotation
        # (conjugate serves as the inverse for unit quaternions).
        q_prev, q_next = abs_rotations[i], abs_rotations[i + 1]
        w1, x1, y1, z1 = q_prev[0], -q_prev[1], -q_prev[2], -q_prev[3]
        w2, x2, y2, z2 = q_next
        relative_rotation = torch.tensor([
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
        ], dtype=torch.float32)
        pose_vecs.append(torch.cat([translation, relative_rotation], dim=0))  # [7]

    pose_sequence = torch.stack(pose_vecs, dim=0)
    mask = torch.zeros(seq_len, 1, dtype=torch.float32)
    mask[start_frame:min(start_frame + current_history_length, seq_len)] = 1.0
    camera_embedding = torch.cat([pose_sequence, mask], dim=1)
    print(f"🔧 NuScenes合成相对pose embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
496
+
497
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """FramePack sliding-window preparation - MoE version.

    Splits the history latents into FramePack's start/4x/2x/1x conditioning
    groups, slices and re-masks the camera sequence, and returns the dict of
    tensors/indices the DiT forward pass expects.
    """
    # history_latents: [C, T, H, W] — the accumulated history so far.
    C, T, H, W = history_latents.shape

    # Fixed index layout: start(1) + 4x(16) + 2x(2) + 1x(1) + targets.
    total_len = 1 + 16 + 2 + 1 + target_frames_to_generate
    idx_start, idx_4x, idx_2x, idx_1x, latent_indices = torch.arange(0, total_len).split(
        [1, 16, 2, 1, target_frames_to_generate], dim=0)
    clean_latent_indices = torch.cat([idx_start, idx_1x], dim=0)

    # Zero-pad the camera sequence if it is shorter than the index layout.
    if camera_embedding_full.shape[0] < total_len:
        pad = torch.zeros(total_len - camera_embedding_full.shape[0], camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Take the window matching the index layout.
    combined_camera = camera_embedding_full[:total_len, :].clone()

    # Re-derive the condition mask from the actual history length.
    combined_camera[:, -1] = 0.0  # everything defaults to target (0)

    # Right-align the history inside the fixed 19-slot conditioning window.
    if T > 0:
        usable = min(T, 19)
        first = 19 - usable
        combined_camera[first:19, -1] = 1.0  # camera rows backing valid clean latents

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {usable if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Build the 19-frame clean-latent buffer, right-aligned with history.
    clean_buf = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        usable = min(T, 19)
        clean_buf[:, 19 - usable:, :, :] = history_latents[:, -usable:, :, :]

    clean_latents_4x = clean_buf[:, 0:16, :, :]
    clean_latents_2x = clean_buf[:, 16:18, :, :]
    clean_latents_1x = clean_buf[:, 18:19, :, :]

    # Start anchor: the very first history frame (zeros when no history yet).
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag for MoE routing
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
566
+
567
+
568
+ def inference_moe_framepack_sliding_window(
569
+ condition_pth_path,
570
+ dit_path,
571
+ output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
572
+ start_frame=0,
573
+ initial_condition_frames=8,
574
+ frames_per_generation=4,
575
+ total_frames_to_generate=32,
576
+ max_history_frames=49,
577
+ device="cuda",
578
+ prompt="A video of a scene shot using a pedestrian's front camera while walking",
579
+ modality_type="sekai", # "sekai" 或 "nuscenes"
580
+ use_real_poses=True,
581
+ scene_info_path=None, # 对于NuScenes数据集
582
+ # CFG参数
583
+ use_camera_cfg=True,
584
+ camera_guidance_scale=2.0,
585
+ text_guidance_scale=1.0,
586
+ # MoE参数
587
+ moe_num_experts=4,
588
+ moe_top_k=2,
589
+ moe_hidden_dim=None
590
+ ):
591
+ """
592
+ MoE FramePack滑动窗口视频生成 - 支持多模态
593
+ """
594
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
595
+ print(f"🔧 MoE FramePack滑动窗口生成开始...")
596
+ print(f"模态类型: {modality_type}")
597
+ print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
598
+ print(f"Text guidance scale: {text_guidance_scale}")
599
+ print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")
600
+
601
+ # 1. 模型初始化
602
+ replace_dit_model_in_manager()
603
+
604
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
605
+ model_manager.load_models([
606
+ "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
607
+ "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
608
+ "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
609
+ ])
610
+ pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
611
+
612
+ # 2. 添加传统camera编码器(兼容性)
613
+ dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
614
+ for block in pipe.dit.blocks:
615
+ block.cam_encoder = nn.Linear(13, dim)
616
+ block.projector = nn.Linear(dim, dim)
617
+ block.cam_encoder.weight.data.zero_()
618
+ block.cam_encoder.bias.data.zero_()
619
+ block.projector.weight = nn.Parameter(torch.eye(dim))
620
+ block.projector.bias = nn.Parameter(torch.zeros(dim))
621
+
622
+ # 3. 添加FramePack组件
623
+ add_framepack_components(pipe.dit)
624
+
625
+ # 4. 添加MoE组件
626
+ moe_config = {
627
+ "num_experts": moe_num_experts,
628
+ "top_k": moe_top_k,
629
+ "hidden_dim": moe_hidden_dim or dim * 2,
630
+ "sekai_input_dim": 13, # Sekai: 12维pose + 1维mask
631
+ "nuscenes_input_dim": 8, # NuScenes: 7维pose + 1维mask
632
+ "openx_input_dim": 13 # OpenX: 12维pose + 1维mask (类似sekai)
633
+ }
634
+ add_moe_components(pipe.dit, moe_config)
635
+
636
+ # 5. 加载训练好的权重
637
+ dit_state_dict = torch.load(dit_path, map_location="cpu")
638
+ pipe.dit.load_state_dict(dit_state_dict, strict=False) # 使用strict=False以兼容新增的MoE组件
639
+ pipe = pipe.to(device)
640
+ model_dtype = next(pipe.dit.parameters()).dtype
641
+
642
+ if hasattr(pipe.dit, 'clean_x_embedder'):
643
+ pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
644
+
645
+ pipe.scheduler.set_timesteps(50)
646
+
647
+ # 6. 加载初始条件
648
+ print("Loading initial condition frames...")
649
+ initial_latents, encoded_data = load_encoded_video_from_pth(
650
+ condition_pth_path,
651
+ start_frame=start_frame,
652
+ num_frames=initial_condition_frames
653
+ )
654
+
655
+ # 空间裁剪
656
+ target_height, target_width = 60, 104
657
+ C, T, H, W = initial_latents.shape
658
+
659
+ if H > target_height or W > target_width:
660
+ h_start = (H - target_height) // 2
661
+ w_start = (W - target_width) // 2
662
+ initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
663
+ H, W = target_height, target_width
664
+
665
+ history_latents = initial_latents.to(device, dtype=model_dtype)
666
+
667
+ print(f"初始history_latents shape: {history_latents.shape}")
668
+
669
+ # 7. 编码prompt - 支持CFG
670
+ if text_guidance_scale > 1.0:
671
+ prompt_emb_pos = pipe.encode_prompt(prompt)
672
+ prompt_emb_neg = pipe.encode_prompt("")
673
+ print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
674
+ else:
675
+ prompt_emb_pos = pipe.encode_prompt(prompt)
676
+ prompt_emb_neg = None
677
+ print("不使用Text CFG")
678
+
679
+ # 8. 加载场景信息(对于NuScenes)
680
+ scene_info = None
681
+ if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
682
+ with open(scene_info_path, 'r') as f:
683
+ scene_info = json.load(f)
684
+ print(f"加载NuScenes场景信息: {scene_info_path}")
685
+
686
+ # 9. 预生成完整的camera embedding序列
687
+ if modality_type == "sekai":
688
+ camera_embedding_full = generate_sekai_camera_embeddings_sliding(
689
+ encoded_data.get('cam_emb', None),
690
+ 0,
691
+ max_history_frames,
692
+ 0,
693
+ 0,
694
+ use_real_poses=use_real_poses
695
+ ).to(device, dtype=model_dtype)
696
+ elif modality_type == "nuscenes":
697
+ camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
698
+ scene_info,
699
+ 0,
700
+ max_history_frames,
701
+ 0
702
+ ).to(device, dtype=model_dtype)
703
+ elif modality_type == "openx":
704
+ camera_embedding_full = generate_openx_camera_embeddings_sliding(
705
+ encoded_data,
706
+ 0,
707
+ max_history_frames,
708
+ 0,
709
+ use_real_poses=use_real_poses
710
+ ).to(device, dtype=model_dtype)
711
+ else:
712
+ raise ValueError(f"不支持的模态类型: {modality_type}")
713
+
714
+ print(f"完整camera序列shape: {camera_embedding_full.shape}")
715
+
716
+ # 10. 为Camera CFG创建无条件的camera embedding
717
+ if use_camera_cfg:
718
+ camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
719
+ print(f"创建无条件camera embedding用于CFG")
720
+
721
+ # 11. 滑动窗口生成循环
722
+ total_generated = 0
723
+ all_generated_frames = []
724
+
725
+ while total_generated < total_frames_to_generate:
726
+ current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
727
+ print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
728
+ print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")
729
+
730
+ # FramePack数据准备 - MoE版本
731
+ framepack_data = prepare_framepack_sliding_window_with_camera_moe(
732
+ history_latents,
733
+ current_generation,
734
+ camera_embedding_full,
735
+ start_frame,
736
+ modality_type,
737
+ max_history_frames
738
+ )
739
+
740
+ # 准备输入
741
+ clean_latents = framepack_data['clean_latents'].unsqueeze(0)
742
+ clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
743
+ clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
744
+ camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
745
+
746
+ # 准备modality_inputs
747
+ modality_inputs = {modality_type: camera_embedding}
748
+
749
+ # 为CFG准备无条件camera embedding
750
+ if use_camera_cfg:
751
+ camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
752
+ modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}
753
+
754
+ # 索引处理
755
+ latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
756
+ clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
757
+ clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
758
+ clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
759
+
760
+ # 初始化要生成的latents
761
+ new_latents = torch.randn(
762
+ 1, C, current_generation, H, W,
763
+ device=device, dtype=model_dtype
764
+ )
765
+
766
+ extra_input = pipe.prepare_extra_input(new_latents)
767
+
768
+ print(f"Camera embedding shape: {camera_embedding.shape}")
769
+ print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
770
+
771
+ # 去噪循环 - 支持CFG
772
+ timesteps = pipe.scheduler.timesteps
773
+
774
+ for i, timestep in enumerate(timesteps):
775
+ if i % 10 == 0:
776
+ print(f" 去噪步骤 {i+1}/{len(timesteps)}")
777
+
778
+ timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
779
+
780
+ with torch.no_grad():
781
+ # CFG推理
782
+ if use_camera_cfg and camera_guidance_scale > 1.0:
783
+ # 条件预测(有camera)
784
+ noise_pred_cond, moe_loess = pipe.dit(
785
+ new_latents,
786
+ timestep=timestep_tensor,
787
+ cam_emb=camera_embedding,
788
+ modality_inputs=modality_inputs, # MoE模态输入
789
+ latent_indices=latent_indices,
790
+ clean_latents=clean_latents,
791
+ clean_latent_indices=clean_latent_indices,
792
+ clean_latents_2x=clean_latents_2x,
793
+ clean_latent_2x_indices=clean_latent_2x_indices,
794
+ clean_latents_4x=clean_latents_4x,
795
+ clean_latent_4x_indices=clean_latent_4x_indices,
796
+ **prompt_emb_pos,
797
+ **extra_input
798
+ )
799
+
800
+ # 无条件预测(无camera)
801
+ noise_pred_uncond, moe_loess = pipe.dit(
802
+ new_latents,
803
+ timestep=timestep_tensor,
804
+ cam_emb=camera_embedding_uncond_batch,
805
+ modality_inputs=modality_inputs_uncond, # MoE无条件模态输入
806
+ latent_indices=latent_indices,
807
+ clean_latents=clean_latents,
808
+ clean_latent_indices=clean_latent_indices,
809
+ clean_latents_2x=clean_latents_2x,
810
+ clean_latent_2x_indices=clean_latent_2x_indices,
811
+ clean_latents_4x=clean_latents_4x,
812
+ clean_latent_4x_indices=clean_latent_4x_indices,
813
+ **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
814
+ **extra_input
815
+ )
816
+
817
+ # Camera CFG
818
+ noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)
819
+
820
+ # 如果同时使用Text CFG
821
+ if text_guidance_scale > 1.0 and prompt_emb_neg:
822
+ noise_pred_text_uncond, moe_loess = pipe.dit(
823
+ new_latents,
824
+ timestep=timestep_tensor,
825
+ cam_emb=camera_embedding,
826
+ modality_inputs=modality_inputs,
827
+ latent_indices=latent_indices,
828
+ clean_latents=clean_latents,
829
+ clean_latent_indices=clean_latent_indices,
830
+ clean_latents_2x=clean_latents_2x,
831
+ clean_latent_2x_indices=clean_latent_2x_indices,
832
+ clean_latents_4x=clean_latents_4x,
833
+ clean_latent_4x_indices=clean_latent_4x_indices,
834
+ **prompt_emb_neg,
835
+ **extra_input
836
+ )
837
+
838
+ # 应用Text CFG到已经应用Camera CFG的结果
839
+ noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
840
+
841
+ elif text_guidance_scale > 1.0 and prompt_emb_neg:
842
+ # 只使用Text CFG
843
+ noise_pred_cond, moe_loess = pipe.dit(
844
+ new_latents,
845
+ timestep=timestep_tensor,
846
+ cam_emb=camera_embedding,
847
+ modality_inputs=modality_inputs,
848
+ latent_indices=latent_indices,
849
+ clean_latents=clean_latents,
850
+ clean_latent_indices=clean_latent_indices,
851
+ clean_latents_2x=clean_latents_2x,
852
+ clean_latent_2x_indices=clean_latent_2x_indices,
853
+ clean_latents_4x=clean_latents_4x,
854
+ clean_latent_4x_indices=clean_latent_4x_indices,
855
+ **prompt_emb_pos,
856
+ **extra_input
857
+ )
858
+
859
+ noise_pred_uncond, moe_loess= pipe.dit(
860
+ new_latents,
861
+ timestep=timestep_tensor,
862
+ cam_emb=camera_embedding,
863
+ modality_inputs=modality_inputs,
864
+ latent_indices=latent_indices,
865
+ clean_latents=clean_latents,
866
+ clean_latent_indices=clean_latent_indices,
867
+ clean_latents_2x=clean_latents_2x,
868
+ clean_latent_2x_indices=clean_latent_2x_indices,
869
+ clean_latents_4x=clean_latents_4x,
870
+ clean_latent_4x_indices=clean_latent_4x_indices,
871
+ **prompt_emb_neg,
872
+ **extra_input
873
+ )
874
+
875
+ noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)
876
+
877
+ else:
878
+ # 标准推理(无CFG)
879
+ noise_pred, moe_loess = pipe.dit(
880
+ new_latents,
881
+ timestep=timestep_tensor,
882
+ cam_emb=camera_embedding,
883
+ modality_inputs=modality_inputs, # MoE模态输入
884
+ latent_indices=latent_indices,
885
+ clean_latents=clean_latents,
886
+ clean_latent_indices=clean_latent_indices,
887
+ clean_latents_2x=clean_latents_2x,
888
+ clean_latent_2x_indices=clean_latent_2x_indices,
889
+ clean_latents_4x=clean_latents_4x,
890
+ clean_latent_4x_indices=clean_latent_4x_indices,
891
+ **prompt_emb_pos,
892
+ **extra_input
893
+ )
894
+
895
+ new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
896
+
897
+ # 更新历史
898
+ new_latents_squeezed = new_latents.squeeze(0)
899
+ history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
900
+
901
+ # 维护滑动窗口
902
+ if history_latents.shape[1] > max_history_frames:
903
+ first_frame = history_latents[:, 0:1, :, :]
904
+ recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
905
+ history_latents = torch.cat([first_frame, recent_frames], dim=1)
906
+ print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")
907
+
908
+ print(f"更新后history_latents shape: {history_latents.shape}")
909
+
910
+ all_generated_frames.append(new_latents_squeezed)
911
+ total_generated += current_generation
912
+
913
+ print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")
914
+
915
+ # 12. 解码和保存
916
+ print("\n🔧 解码生成的视频...")
917
+
918
+ all_generated = torch.cat(all_generated_frames, dim=1)
919
+ final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
920
+
921
+ print(f"最终视频shape: {final_video.shape}")
922
+
923
+ decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
924
+
925
+ print(f"Saving video to {output_path}")
926
+
927
+ video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
928
+ video_np = (video_np * 0.5 + 0.5).clip(0, 1)
929
+ video_np = (video_np * 255).astype(np.uint8)
930
+
931
+ with imageio.get_writer(output_path, fps=20) as writer:
932
+ for frame in video_np:
933
+ writer.append_data(frame)
934
+
935
+ print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
936
+ print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
937
+ print(f"使用模态: {modality_type}")
938
+
939
+
940
def str2bool(value):
    """Parse common truthy/falsy CLI strings into a real bool.

    argparse with a bare ``default=`` (and no ``type=``) treats any non-empty
    command-line string — including "False" — as truthy, so boolean options
    need an explicit parser.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("1", "true", "t", "yes", "y")


def main():
    """CLI entry point for MoE FramePack sliding-window video generation."""
    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    # Basic I/O and sliding-window parameters.
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    #default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
    #default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
    #default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=24)
    parser.add_argument("--max_history_frames", type=int, default=100)
    # BUGFIX: was ``default=True`` with no ``type=``, so "--use_real_poses False"
    # parsed as the truthy string "False". Same default, real bool parsing now.
    parser.add_argument("--use_real_poses", type=str2bool, default=True)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step25000_first.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world ")
    parser.add_argument("--device", type=str, default="cuda")

    # Modality selection.
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
                        help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # Classifier-free-guidance parameters.
    # BUGFIX: same bare-default boolean pitfall as --use_real_poses above.
    parser.add_argument("--use_camera_cfg", type=str2bool, default=False,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # MoE parameters.
    parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")

    args = parser.parse_args()

    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
    # BUGFIX: separator was missing between "DiT" and the checkpoint path.
    print(f"DiT: {args.dit_path}")

    # NuScenes without pose metadata falls back to synthetic poses downstream.
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim
    )
1020
+
1021
+
1022
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
scripts/infer_moe_spatialvid.py ADDED
@@ -0,0 +1,1008 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import imageio
7
+ import json
8
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
9
+ import argparse
10
+ from torchvision.transforms import v2
11
+ from einops import rearrange
12
+ import copy
13
+ from scipy.spatial.transform import Rotation as R
14
+
15
def compute_relative_pose_matrix(pose1, pose2):
    """Return the 3x4 relative camera matrix [R_rel | t_rel] between two frames.

    Args:
        pose1: pose of frame i, a length-7 array [tx, ty, tz, qx, qy, qz, qw].
        pose2: pose of frame i+1 in the same layout.

    Returns:
        np.ndarray of shape (3, 4): the relative rotation in the first three
        columns and the relative translation in the last column.
    """
    # Split each pose into its translation and quaternion parts.
    t_prev, q_prev = pose1[:3], pose1[3:]
    t_next, q_next = pose2[:3], pose2[3:]

    # Relative rotation: next-frame rotation composed with the inverse of the
    # previous-frame rotation.
    rot_prev = R.from_quat(q_prev)
    rot_next = R.from_quat(q_next)
    R_rel = (rot_next * rot_prev.inv()).as_matrix()

    # Relative translation expressed in the previous camera frame:
    # R_prev^T @ (t_next - t_prev).
    t_rel = rot_prev.as_matrix().T @ (t_next - t_prev)

    # Pack into a single 3x4 matrix [R_rel | t_rel].
    return np.hstack([R_rel, t_rel.reshape(3, 1)])
46
+
47
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice a frame window.

    Args:
        pth_path: path to a torch-saved dict containing a 'latents' tensor [C, T, H, W].
        start_frame: first latent frame (inclusive) of the window.
        num_frames: number of latent frames to extract.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice and
        the full dict loaded from disk.

    Raises:
        ValueError: if the requested window extends past the stored sequence.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # expected layout [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {end_frame}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
64
+
65
+
66
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A as ``pose_b @ inv(pose_a)``.

    Args:
        pose_a: 4x4 extrinsic matrix of camera A (numpy array or torch tensor).
        pose_b: 4x4 extrinsic matrix of camera B.
        use_torch: if True, compute with torch; otherwise with numpy.

    Returns:
        4x4 relative pose matrix in the backend selected by ``use_torch``.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Coerce numpy inputs to float tensors before inverting.
        a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(b, torch.inverse(a))

    # numpy path: coerce list-like inputs to float32 arrays.
    a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(b, np.linalg.inv(a))
89
+
90
+
91
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' loader class for the MoE variant.

    Mutates diffsynth's global ``model_loader_configs`` list in place so that
    subsequent model loading instantiates ``WanModelMoe`` instead of the stock
    DiT class. Entries that do not mention 'wan_video_dit' are left untouched.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        if 'wan_video_dit' not in model_names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(model_names, model_classes):
            if name == 'wan_video_dit':
                patched_names.append(name)
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_names.append(name)
                patched_classes.append(cls)

        # Rewrite the config tuple in place so later loaders see the MoE class.
        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, model_resource)
113
+
114
+
115
def add_framepack_components(dit_model):
    """Attach FramePack's clean-latent embedder to ``dit_model`` if missing.

    The embedder downsamples 16-channel clean latents at 1x/2x/4x scales via
    strided 3D convolutions whose output channels match the transformer width
    (read from the first block's query projection). Idempotent: does nothing
    when ``clean_x_embedder`` is already present.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Transformer hidden width, inferred from the first self-attention q weight.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch to the conv matching the requested scale, casting the
            # input to that layer's parameter dtype first.
            convs = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in convs:
                raise ValueError(f"Unsupported scale: {scale}")
            conv = convs[scale]
            return conv(x.to(conv.weight.dtype))

    # Match the embedder's dtype to the rest of the model's parameters.
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim).to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
144
+
145
+
146
def add_moe_components(dit_model, moe_config):
    """Attach per-block MoE routing components and store ``moe_config``.

    Every transformer block receives a 'sekai' ModalityProcessor (13-dim pose
    input -> ``unified_dim``) and a MultiModalMoE head mapping ``unified_dim``
    to the block's hidden width. The config dict itself is stored on the model
    the first time this runs.
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")

    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

    # Hidden width of the transformer blocks (from the q projection weight).
    hidden_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)

    for idx, block in enumerate(dit_model.blocks):
        # Sekai modality: 13-dim camera embedding projected to unified_dim.
        block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)

        # MoE head: unified_dim in, transformer hidden width out.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=hidden_dim,
            num_experts=num_experts,
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {idx} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {num_experts})")
174
+
175
+
176
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Generate per-frame camera embeddings for the Sekai dataset (sliding-window version).

    Each output row is a flattened 3x4 relative pose (12 values) plus one
    trailing condition-mask column (1.0 = conditioning frame, 0.0 = frame to
    be generated), giving 13 columns total.

    Args:
        cam_data: dict with an 'extrinsic' pose sequence, or None to force synthetic motion.
        start_frame: first latent frame treated as conditioning.
        current_history_length: number of latent frames already available as history.
        new_frames: number of new latent frames about to be generated.
        total_generated: running count of generated frames — unused in this body;
            presumably kept for signature compatibility with callers (TODO confirm).
        use_real_poses: when True and cam_data provides extrinsics, derive poses
            from data; otherwise synthesize a constant left-turning motion.

    Returns:
        torch.bfloat16 tensor of shape [max_needed_frames, 13].
    """
    # One latent frame corresponds to 4 raw video frames.
    time_compression_ratio = 4

    # Camera frames required by the fixed FramePack index layout:
    # [start(1) | 4x(16) | 2x(2) | 1x(1) | new_frames].
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Generate a camera sequence long enough for every consumer
        # (sliding window, FramePack layout, and a floor of 30 frames).
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Map latent index i back to raw-frame indices (stride of 4).
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                # Relative pose between the two sampled frames; poses here are
                # the 7D [t, quat] layout consumed by compute_relative_pose_matrix.
                relative_pose = compute_relative_pose_matrix(cam_prev, cam_next)
                # relative_pose is already 3x4; the [:3, :] slice is a no-op kept
                # for safety.
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        # Flatten each 3x4 pose matrix into a 12-vector per frame.
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask: frames [start_frame, start_frame + history) are conditioning.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用Sekai合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # Synthetic motion: constant per-frame left turn with slow forward drift.
            yaw_per_frame = 0.05    # left-turn angle per frame (positive = left)
            forward_speed = 0.005   # forward distance per frame

            pose = np.eye(4, dtype=np.float32)

            # Rotation about the Y axis (left turn).
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Translation in the rotated local frame: negative local Z = forward.
            pose[2, 3] = -forward_speed

            # Slight drift toward the circle center, approximating a circular path.
            radius_drift = 0.002
            pose[0, 3] = -radius_drift  # negative local X = leftward

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask, same convention as the real-pose branch above.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
276
+
277
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Generate per-frame camera embeddings for the OpenX dataset (sliding-window version).

    Same 13-column layout as the Sekai generator: a flattened 3x4 relative pose
    (12 values) plus a trailing condition-mask column. The real-pose branch uses
    4x4 extrinsic matrices (via compute_relative_pose), unlike the Sekai branch's
    7D quaternion poses.

    Args:
        encoded_data: dict expected to hold encoded_data['cam_emb']['extrinsic']
            (a sequence of 4x4 extrinsics); any missing piece triggers synthesis.
        start_frame: first latent frame treated as conditioning.
        current_history_length: number of latent frames already available as history.
        new_frames: number of new latent frames about to be generated.
        use_real_poses: when True and extrinsics are available, derive poses from
            data; otherwise synthesize small robot-arm-like motion.

    Returns:
        torch.bfloat16 tensor of shape [max_needed_frames, 13].
    """
    # One latent frame corresponds to 4 raw video frames.
    time_compression_ratio = 4

    # Camera frames required by the fixed FramePack index layout:
    # [start(1) | 4x(16) | 2x(2) | 1x(1) | new_frames].
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        # Generate a camera sequence long enough for every consumer.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Same 4-frame stride as the Sekai path, applied to (typically
            # shorter) OpenX episodes.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                # 4x4 relative pose; only the top 3 rows are kept below.
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        # Flatten each 3x4 pose matrix into a 12-vector per frame.
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask: frames [start_frame, start_frame + history) are conditioning.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用OpenX合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # Synthetic robot-manipulation motion: small per-frame rotations
            # and slow translation, mimicking fine robot-arm movement.
            roll_per_frame = 0.02   # slight roll
            pitch_per_frame = 0.01  # slight pitch
            yaw_per_frame = 0.015   # slight yaw
            forward_speed = 0.003   # slow forward speed

            pose = np.eye(4, dtype=np.float32)

            # Composite rotation terms (roll about X, pitch about Y, yaw about Z).
            cos_roll = np.cos(roll_per_frame)
            sin_roll = np.sin(roll_per_frame)
            cos_pitch = np.cos(pitch_per_frame)
            sin_pitch = np.sin(pitch_per_frame)
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            # Simplified composite rotation matrix (ZYX order).
            pose[0, 0] = cos_yaw * cos_pitch
            pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
            pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
            pose[1, 0] = sin_yaw * cos_pitch
            pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
            pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
            pose[2, 0] = -sin_pitch
            pose[2, 1] = cos_pitch * sin_roll
            pose[2, 2] = cos_pitch * cos_roll

            # Translation: small X/Y adjustments, main motion along Z (depth).
            pose[0, 3] = forward_speed * 0.5
            pose[1, 3] = forward_speed * 0.3
            pose[2, 3] = -forward_speed

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask, same convention as the real-pose branch above.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
391
+
392
+
393
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Generate per-frame camera embeddings for NuScenes (sliding-window version).

    Kept consistent with train_moe.py (per the original author's note). Each
    output row is a 7D pose vector [tx, ty, tz, qw?, qx?, qy?, qz?] plus a
    trailing condition-mask column, giving 8 columns total — narrower than the
    13-column Sekai/OpenX embeddings.

    NOTE(review): unlike the Sekai/OpenX generators, keyframe poses are indexed
    directly per latent frame (no 4x stride), and the rotation stored is the
    absolute keyframe quaternion, not a frame-to-frame relative rotation —
    confirm this matches what train_moe.py feeds the model.

    Args:
        scene_info: dict with a 'keyframe_poses' list of {'translation', 'rotation'}
            entries, or None to synthesize straight-line motion.
        start_frame: first latent frame treated as conditioning.
        current_history_length: number of latent frames already available as history.
        new_frames: number of new latent frames about to be generated.

    Returns:
        torch.bfloat16 tensor of shape [max_needed_frames, 8].
    """
    # Declared for parity with the other generators; unused in this body.
    time_compression_ratio = 4

    # Camera frames required by the fixed FramePack index layout:
    # [start(1) | 4x(16) | 2x(2) | 1x(1) | new_frames].
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if len(keyframe_poses) == 0:
            # Degenerate scene: emit all-zero poses with the usual mask layout.
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            max_needed_frames = max(framepack_needed_frames, 30)

            pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)

            mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
            condition_end = min(start_frame + current_history_length, max_needed_frames)
            mask[start_frame:condition_end] = 1.0

            camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
            print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)

        # Translations are expressed relative to the first keyframe.
        reference_pose = keyframe_poses[0]

        max_needed_frames = max(framepack_needed_frames, 30)

        pose_vecs = []
        for i in range(max_needed_frames):
            if i < len(keyframe_poses):
                current_pose = keyframe_poses[i]

                # Relative displacement w.r.t. the reference keyframe.
                translation = torch.tensor(
                    np.array(current_pose['translation']) - np.array(reference_pose['translation']),
                    dtype=torch.float32
                )

                # Absolute keyframe quaternion (simplified — not made relative).
                rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)

                pose_vec = torch.cat([translation, rotation], dim=0)  # [7D]
            else:
                # Past the recorded keyframes: zero translation, identity quaternion.
                pose_vec = torch.cat([
                    torch.zeros(3, dtype=torch.float32),
                    torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
                ], dim=0)  # [7D]

            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)  # [max_needed_frames, 7]

        # Condition mask: frames [start_frame, start_frame + history) are conditioning.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用NuScenes合成pose数据")
        max_needed_frames = max(framepack_needed_frames, 30)

        # Synthetic trajectory: straight forward motion along Z, no rotation.
        pose_vecs = []
        for i in range(max_needed_frames):
            translation = torch.tensor([0.0, 0.0, i * 0.1], dtype=torch.float32)  # advance along Z
            rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)    # identity quaternion

            pose_vec = torch.cat([translation, rotation], dim=0)  # [7D]
            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)

        # Condition mask, same convention as above.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes合成pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
482
+
483
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """Assemble FramePack sliding-window inputs (MoE variant).

    Args:
        history_latents: [C, T, H, W] latents accumulated so far (T may be 0).
        target_frames_to_generate: number of new latent frames to denoise.
        camera_embedding_full: [N, D] camera sequence whose last column is the
            condition/target mask; the mask column is rebuilt here.
        start_frame: conditioning start offset (not used for the mask here,
            which is derived from the history length instead).
        modality_type: modality tag passed through for logging/downstream use.
        max_history_frames: sliding-window capacity (informational only here).

    Returns:
        dict with clean latents at 1x/2x/4x scales, their index tensors, the
        mask-rewritten camera window, and bookkeeping lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed FramePack index layout: [start(1) | 4x(16) | 2x(2) | 1x(1) | new].
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    groups = torch.arange(0, total_indices_length).split(
        [1, 16, 2, 1, target_frames_to_generate], dim=0)
    start_idx, clean_latent_4x_indices, clean_latent_2x_indices, idx_1x, latent_indices = groups
    clean_latent_indices = torch.cat([start_idx, idx_1x], dim=0)

    # Zero-pad the camera sequence when it is shorter than the index layout.
    deficit = total_indices_length - camera_embedding_full.shape[0]
    if deficit > 0:
        pad = torch.zeros(deficit, camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Slice the window and rebuild the mask column from scratch:
    # everything defaults to target (0), then the trailing portion of the 19
    # clean slots backed by real history is flagged as condition (1).
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()
    combined_camera[:, -1] = 0.0

    available_frames = min(T, 19) if T > 0 else 0
    if T > 0:
        combined_camera[19 - available_frames:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Build the 19-slot clean-latent buffer, right-aligned with real history.
    clean_buffer = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        clean_buffer[:, 19 - available_frames:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_buffer[:, 0:16, :, :]
    clean_latents_2x = clean_buffer[:, 16:18, :, :]
    clean_latents_1x = clean_buffer[:, 18:19, :, :]

    # Anchor frame: the very first history frame, or zeros before any history.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
552
+
553
+
554
def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai", "nuscenes" or "openx"
    use_real_poses=True,
    scene_info_path=None,  # only used for the NuScenes modality
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None
):
    """MoE FramePack sliding-window video generation with multi-modality support.

    Loads the Wan2.1 T2V pipeline, patches its DiT with a legacy camera
    encoder, FramePack clean-latent embedders and per-block MoE components,
    then autoregressively generates latent frames in chunks of
    ``frames_per_generation``, conditioned on a sliding window of history
    latents plus per-frame camera embeddings. Optional classifier-free
    guidance (CFG) is applied on the camera branch and/or the text branch.
    The decoded video is written to ``output_path``.

    Args:
        condition_pth_path: .pth file with pre-encoded latents (and optionally
            camera data under 'cam_emb') used as the initial condition.
        dit_path: Trained DiT checkpoint (state dict) to load on top.
        output_path: Output mp4 path; parent directories are created.
        start_frame: First latent frame to take from the condition file.
        initial_condition_frames: Number of condition latent frames to load.
        frames_per_generation: Latent frames generated per sliding-window step.
        total_frames_to_generate: Total latent frames to generate.
        max_history_frames: Cap on the history window (frame 0 stays pinned).
        device: Torch device used for inference.
        prompt: Text prompt fed to the T5 encoder.
        modality_type: One of "sekai", "nuscenes", "openx"; selects the
            camera-embedding generator and the MoE routing key.
        use_real_poses: Prefer recorded camera poses over synthetic ones.
        scene_info_path: JSON scene metadata (NuScenes only).
        use_camera_cfg: Enable CFG on the camera conditioning.
        camera_guidance_scale: CFG scale for the camera branch (>1 active).
        text_guidance_scale: CFG scale for the text branch (>1 active).
        moe_num_experts / moe_top_k / moe_hidden_dim: MoE configuration.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialization: register the MoE DiT class before loading, then
    # load Wan2.1 weights (DiT + T5 text encoder + VAE) on CPU in bf16.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Attach the legacy per-block camera encoder (kept for checkpoint
    # compatibility): zero-init encoder + identity projector, i.e. a no-op
    # until the trained checkpoint overwrites these weights.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. FramePack components (multi-scale clean-latent embedder).
    add_framepack_components(pipe.dit)

    # 4. MoE components (router + experts per block).
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,  # Sekai: 12-dim pose + 1-dim mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-dim pose + 1-dim mask
        "openx_input_dim": 13  # OpenX: 12-dim pose + 1-dim mask (same layout as sekai)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load the fine-tuned checkpoint; strict=False tolerates the newly
    # added MoE / FramePack parameter names missing from the checkpoint.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition latents from the pre-encoded file.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop in latent space to the training grid of 60x104 latent cells
    # (presumably matching 480p pixel inputs — TODO confirm).
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the prompt; an empty-prompt ("negative") embedding is only
    # needed when text CFG is enabled.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 8. Optional NuScenes scene metadata (real ego poses).
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera-embedding sequence for the chosen
    # modality; per-step windows are sliced from it inside the loop.
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            0,
            max_history_frames,
            0,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            0,
            max_history_frames,
            0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            0,
            max_history_frames,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. An all-zero camera sequence acts as the unconditional branch for CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window autoregressive generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation (MoE variant): clean latents at
        # 1x/2x/4x scales, index tensors and the windowed camera embedding.
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # MoE routes on the modality key.
        modality_inputs = {modality_type: camera_embedding}

        # Trim the zero-camera sequence to this window's length for CFG
        # (camera_embedding is [1, T, D] here, so shape[1] is T).
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors are kept on CPU.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Initialize the frames to generate from Gaussian noise.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop with optional camera / text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # CFG inference (the auxiliary MoE loss returned by the DiT is
                # ignored at inference time).
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera).
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zeroed camera).
                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # MoE unconditional modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG combination.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack text CFG on top of the camera-guided result.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loss = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the already camera-guided prediction.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only.
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Standard inference (no CFG).
                    noise_pred, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the denoised frames to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep frame 0 pinned plus the most
        # recent (max_history_frames - 1) frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode latents to pixels and write the mp4.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] float latent-decoder output -> [0, 255] uint8, frame-major
    # layout (T, H, W, C) for imageio.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
924
+
925
+
926
def main():
    """CLI entry point: parse arguments and run MoE FramePack sliding-window inference.

    Bug fix: ``--use_camera_cfg`` was previously declared with only
    ``default=True`` (no ``type``/``action``), so any command-line value —
    including ``False`` — arrived as a non-empty string and was always truthy,
    making the flag impossible to disable. It now goes through an explicit
    string-to-bool parser, so ``--use_camera_cfg False`` works while
    ``--use_camera_cfg True`` and the default remain unchanged.
    """

    def _parse_bool(value):
        """Interpret common true/false spellings from the command line."""
        if isinstance(value, bool):
            return value
        return value.strip().lower() in ("1", "true", "t", "yes", "y")

    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    # Basic arguments (alternate condition files kept for reference):
    # default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth"
    # default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth"
    # default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth"
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=8)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_spatialvid/step250_moe.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A man enter the room")
    parser.add_argument("--device", type=str, default="cuda")

    # Modality arguments
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
                        help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # CFG arguments
    parser.add_argument("--use_camera_cfg", type=_parse_bool, default=True,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # MoE arguments
    parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")

    args = parser.parse_args()

    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")

    # NuScenes needs a scene-info file for real poses; warn when missing.
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim
    )
1005
+
1006
+
1007
# Script entry point: parse CLI arguments and run MoE FramePack
# sliding-window inference.
if __name__ == "__main__":
    main()
scripts/infer_moe_test.py ADDED
@@ -0,0 +1,976 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import imageio
7
+ import json
8
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
9
+ import argparse
10
+ from torchvision.transforms import v2
11
+ from einops import rearrange
12
+ import copy
13
+
14
+
15
+ def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
16
+ """从pth文件加载预编码的视频数据"""
17
+ print(f"Loading encoded video from {pth_path}")
18
+
19
+ encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
20
+ full_latents = encoded_data['latents'] # [C, T, H, W]
21
+
22
+ print(f"Full latents shape: {full_latents.shape}")
23
+ print(f"Extracting frames {start_frame} to {start_frame + num_frames}")
24
+
25
+ if start_frame + num_frames > full_latents.shape[1]:
26
+ raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")
27
+
28
+ condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :]
29
+ print(f"Extracted condition latents shape: {condition_latents.shape}")
30
+
31
+ return condition_latents, encoded_data
32
+
33
+
34
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return camera B's pose expressed relative to camera A, i.e. B @ A^-1.

    Args:
        pose_a: 4x4 extrinsic matrix of the reference camera A.
        pose_b: 4x4 extrinsic matrix of camera B.
        use_torch: When True, compute with torch (converting numpy inputs);
            otherwise compute with numpy (converting non-array inputs).

    Returns:
        4x4 relative pose, as a torch.Tensor or np.ndarray per ``use_torch``.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        ref = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        tgt = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(tgt, torch.inverse(ref))

    ref = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    tgt = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(tgt, np.linalg.inv(ref))
57
+
58
+
59
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' loader class for the MoE variant.

    Mutates diffsynth's global ``model_loader_configs`` in place so that any
    subsequent model loading instantiates ``WanModelMoe`` instead of the stock
    DiT class. Entries that do not reference 'wan_video_dit' are untouched.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        if 'wan_video_dit' not in model_names:
            continue

        updated_names, updated_classes = [], []
        for name, cls in zip(model_names, model_classes):
            if name == 'wan_video_dit':
                updated_names.append(name)
                updated_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                updated_names.append(name)
                updated_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, updated_names, updated_classes, model_resource)
81
+
82
+
83
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to the DiT.

    Idempotent: does nothing if ``dit_model`` already has a
    ``clean_x_embedder``. The embedder patchifies 16-channel clean latents at
    1x, 2x and 4x temporal/spatial compression into the model's hidden dim,
    and is cast to the model's parameter dtype.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Project clean latents to tokens at three compression scales."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            projections = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in projections:
                raise ValueError(f"Unsupported scale: {scale}")
            conv = projections[scale]
            # Cast input to the conv's dtype so mixed-precision inputs work.
            return conv(x.to(conv.weight.dtype))

    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
112
+
113
+
114
def add_moe_components(dit_model, moe_config):
    """🔧 Attach MoE components to every DiT block (corrected version).

    Stores ``moe_config`` on the model (only if not already present), then
    gives each transformer block a Sekai modality processor (13-dim pose+mask
    -> unified_dim) and a ``MultiModalMoE`` mapping unified_dim to the block's
    hidden dim. Components are (re)created on every call.
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")

    # Import once instead of per block iteration (behavior-identical; Python
    # caches module imports).
    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)

    for i, block in enumerate(dit_model.blocks):
        # Sekai modality processor -> unified_dim.
        block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)

        # NOTE: a NuScenes processor (8 -> unified_dim) is intentionally
        # left disabled here.

        # MoE network: unified_dim in, transformer hidden dim out.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
142
+
143
+
144
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build the [T, 13] Sekai camera embedding for the sliding window.

    Each row is a flattened 3x4 frame-to-frame relative pose (12 dims) plus a
    mask channel: 1.0 marks condition (history) frames, 0.0 marks frames to
    generate. Real extrinsics are used when available; otherwise a synthetic
    constant-left-turn trajectory is produced. Returns a bfloat16 tensor.
    """
    time_compression_ratio = 4

    # FramePack index layout: start(1) + 4x(16) + 2x(2) + 1x(1) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    # The sequence must cover history + new frames, the FramePack layout, and
    # a floor of 30 frames (identical in both branches, hoisted here).
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    has_real_data = use_real_poses and cam_data is not None and 'extrinsic' in cam_data

    relative_poses = []
    if has_real_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        for i in range(max_needed_frames):
            # Map the compressed-latent index back to raw video frames
            # (x4 temporal compression).
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
    else:
        print("🔧 使用Sekai合成camera数据")
        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")

        for i in range(max_needed_frames):
            # Synthetic trajectory: constant left turn while moving forward.
            yaw_per_frame = 0.05    # per-frame left turn (positive = left)
            forward_speed = 0.005   # per-frame forward step

            pose = np.eye(4, dtype=np.float32)

            # Rotation about the Y axis (turn left).
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)
            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Advance along local -Z (forward in the rotated frame).
            pose[2, 3] = -forward_speed

            # Slight drift toward the circle center to mimic a circular path.
            radius_drift = 0.002
            pose[0, 3] = -radius_drift

            relative_poses.append(torch.as_tensor(pose[:3, :]))

    # Flatten each 3x4 pose to 12 dims: [T, 3, 4] -> [T, 12].
    pose_embedding = rearrange(torch.stack(relative_poses, dim=0), 'b c d -> b (c d)')

    # Mask channel: condition frames (history) get 1.0, targets stay 0.0.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    if has_real_data:
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
    else:
        print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
244
+
245
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Build the [T, 13] OpenX camera embedding for the sliding window.

    Same layout as the Sekai variant: flattened 3x4 relative pose (12 dims)
    plus a condition-mask channel. Real extrinsics come from
    ``encoded_data['cam_emb']['extrinsic']`` when available; otherwise a
    synthetic small-amplitude robot-manipulation motion is produced.
    Returns a bfloat16 tensor.
    """
    time_compression_ratio = 4

    # FramePack index layout: start(1) + 4x(16) + 2x(2) + 1x(1) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    # Must cover history + new frames, the FramePack layout, and a floor of 30
    # (identical in both branches, hoisted here).
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    has_real_data = (use_real_poses and encoded_data is not None
                     and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb'])

    relative_poses = []
    if has_real_data:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        for i in range(max_needed_frames):
            # OpenX uses the same x4 temporal stride as sekai, but episodes
            # are typically shorter.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
    else:
        print("🔧 使用OpenX合成camera数据")
        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

        for i in range(max_needed_frames):
            # Synthetic robot-manipulation motion: small compound rotations
            # plus a slow translation, mimicking fine end-effector movement.
            roll_per_frame = 0.02    # slight roll
            pitch_per_frame = 0.01   # slight pitch
            yaw_per_frame = 0.015    # slight yaw
            forward_speed = 0.003    # slow advance

            pose = np.eye(4, dtype=np.float32)

            # Rotation about X (roll).
            cos_roll = np.cos(roll_per_frame)
            sin_roll = np.sin(roll_per_frame)
            # Rotation about Y (pitch).
            cos_pitch = np.cos(pitch_per_frame)
            sin_pitch = np.sin(pitch_per_frame)
            # Rotation about Z (yaw).
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            # Simplified compound rotation matrix (ZYX order).
            pose[0, 0] = cos_yaw * cos_pitch
            pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
            pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
            pose[1, 0] = sin_yaw * cos_pitch
            pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
            pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
            pose[2, 0] = -sin_pitch
            pose[2, 1] = cos_pitch * sin_roll
            pose[2, 2] = cos_pitch * cos_roll

            # Translation: slight X/Y drift, dominant motion along depth (Z).
            pose[0, 3] = forward_speed * 0.5
            pose[1, 3] = forward_speed * 0.3
            pose[2, 3] = -forward_speed

            relative_poses.append(torch.as_tensor(pose[:3, :]))

    # Flatten each 3x4 pose to 12 dims: [T, 3, 4] -> [T, 12].
    pose_embedding = rearrange(torch.stack(relative_poses, dim=0), 'b c d -> b (c d)')

    # Mask channel: condition frames (history) get 1.0, targets stay 0.0.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    if has_real_data:
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
    else:
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
359
+
360
+
361
+ def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
362
+ """为NuScenes数据集生成camera embeddings - 滑动窗口版本 - 修正版,与train_moe.py保持一致"""
363
+ time_compression_ratio = 4
364
+
365
+ # 计算FramePack实际需要的camera帧数
366
+ framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
367
+
368
+ if scene_info is not None and 'keyframe_poses' in scene_info:
369
+ print("🔧 使用NuScenes真实pose数据")
370
+ keyframe_poses = scene_info['keyframe_poses']
371
+
372
+ if len(keyframe_poses) == 0:
373
+ print("⚠️ NuScenes keyframe_poses为空,使用零pose")
374
+ max_needed_frames = max(framepack_needed_frames, 30)
375
+
376
+ pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)
377
+
378
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
379
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
380
+ mask[start_frame:condition_end] = 1.0
381
+
382
+ camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
383
+ print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
384
+ return camera_embedding.to(torch.bfloat16)
385
+
386
+ # 使用第一个pose作为参考
387
+ reference_pose = keyframe_poses[0]
388
+
389
+ max_needed_frames = max(framepack_needed_frames, 30)
390
+
391
+ pose_vecs = []
392
+ for i in range(max_needed_frames):
393
+ if i < len(keyframe_poses):
394
+ current_pose = keyframe_poses[i]
395
+
396
+ # 计算相对位移
397
+ translation = torch.tensor(
398
+ np.array(current_pose['translation']) - np.array(reference_pose['translation']),
399
+ dtype=torch.float32
400
+ )
401
+
402
+ # 计算相对旋转(简化版本)
403
+ rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)
404
+
405
+ pose_vec = torch.cat([translation, rotation], dim=0) # [7D]
406
+ else:
407
+ # 超出范围,使用零pose
408
+ pose_vec = torch.cat([
409
+ torch.zeros(3, dtype=torch.float32),
410
+ torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
411
+ ], dim=0) # [7D]
412
+
413
+ pose_vecs.append(pose_vec)
414
+
415
+ pose_sequence = torch.stack(pose_vecs, dim=0) # [max_needed_frames, 7]
416
+
417
+ # 创建mask
418
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
419
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
420
+ mask[start_frame:condition_end] = 1.0
421
+
422
+ camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
423
+ print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
424
+ return camera_embedding.to(torch.bfloat16)
425
+
426
+ else:
427
+ print("🔧 使用NuScenes合成pose数据")
428
+ max_needed_frames = max(framepack_needed_frames, 30)
429
+
430
+ # 创建合成运动序列
431
+ pose_vecs = []
432
+ for i in range(max_needed_frames):
433
+ # 简单的前进运动
434
+ translation = torch.tensor([0.0, 0.0, i * 0.1], dtype=torch.float32) # 沿Z轴前进
435
+ rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) # 无旋转
436
+
437
+ pose_vec = torch.cat([translation, rotation], dim=0) # [7D]
438
+ pose_vecs.append(pose_vec)
439
+
440
+ pose_sequence = torch.stack(pose_vecs, dim=0)
441
+
442
+ # 创建mask
443
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
444
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
445
+ mask[start_frame:condition_end] = 1.0
446
+
447
+ camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
448
+ print(f"🔧 NuScenes合成pose embedding shape: {camera_embedding.shape}")
449
+ return camera_embedding.to(torch.bfloat16)
450
+
451
+ def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
452
+ """FramePack滑动窗口机制 - MoE版本"""
453
+ # history_latents: [C, T, H, W] 当前的历史latents
454
+ C, T, H, W = history_latents.shape
455
+
456
+ # 固定索引结构(这决定了需要的camera帧数)
457
+ total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
458
+ indices = torch.arange(0, total_indices_length)
459
+ split_sizes = [1, 16, 2, 1, target_frames_to_generate]
460
+ clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \
461
+ indices.split(split_sizes, dim=0)
462
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)
463
+
464
+ # 检查camera长度是否足够
465
+ if camera_embedding_full.shape[0] < total_indices_length:
466
+ shortage = total_indices_length - camera_embedding_full.shape[0]
467
+ padding = torch.zeros(shortage, camera_embedding_full.shape[1],
468
+ dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
469
+ camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)
470
+
471
+ # 从完整camera序列中选取对应部分
472
+ combined_camera = camera_embedding_full[:total_indices_length, :].clone()
473
+
474
+ # 根据当前history length重新设置mask
475
+ combined_camera[:, -1] = 0.0 # 先全部设为target (0)
476
+
477
+ # 设置condition mask:前19帧根据实际历史长度决定
478
+ if T > 0:
479
+ available_frames = min(T, 19)
480
+ start_pos = 19 - available_frames
481
+ combined_camera[start_pos:19, -1] = 1.0 # 将有效的clean latents对应的camera标记为condition
482
+
483
+ print(f"🔧 MoE Camera mask更新:")
484
+ print(f" - 历史帧数: {T}")
485
+ print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
486
+ print(f" - 模态类型: {modality_type}")
487
+
488
+ # 处理latents
489
+ clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
490
+
491
+ if T > 0:
492
+ available_frames = min(T, 19)
493
+ start_pos = 19 - available_frames
494
+ clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]
495
+
496
+ clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
497
+ clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
498
+ clean_latents_1x = clean_latents_combined[:, 18:19, :, :]
499
+
500
+ if T > 0:
501
+ start_latent = history_latents[:, 0:1, :, :]
502
+ else:
503
+ start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)
504
+
505
+ clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)
506
+
507
+ return {
508
+ 'latent_indices': latent_indices,
509
+ 'clean_latents': clean_latents,
510
+ 'clean_latents_2x': clean_latents_2x,
511
+ 'clean_latents_4x': clean_latents_4x,
512
+ 'clean_latent_indices': clean_latent_indices,
513
+ 'clean_latent_2x_indices': clean_latent_2x_indices,
514
+ 'clean_latent_4x_indices': clean_latent_4x_indices,
515
+ 'camera_embedding': combined_camera,
516
+ 'modality_type': modality_type, # 新增模态类型信息
517
+ 'current_length': T,
518
+ 'next_length': T + target_frames_to_generate
519
+ }
520
+
521
+
522
+ def inference_moe_framepack_sliding_window(
523
+ condition_pth_path,
524
+ dit_path,
525
+ output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
526
+ start_frame=0,
527
+ initial_condition_frames=8,
528
+ frames_per_generation=4,
529
+ total_frames_to_generate=32,
530
+ max_history_frames=49,
531
+ device="cuda",
532
+ prompt="A video of a scene shot using a pedestrian's front camera while walking",
533
+ modality_type="sekai", # "sekai" 或 "nuscenes"
534
+ use_real_poses=True,
535
+ scene_info_path=None, # 对于NuScenes数据集
536
+ # CFG参数
537
+ use_camera_cfg=True,
538
+ camera_guidance_scale=2.0,
539
+ text_guidance_scale=1.0,
540
+ # MoE参数
541
+ moe_num_experts=4,
542
+ moe_top_k=2,
543
+ moe_hidden_dim=None
544
+ ):
545
+ """
546
+ MoE FramePack滑动窗口视频生成 - 支持多模态
547
+ """
548
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
549
+ print(f"🔧 MoE FramePack滑动窗口生成开始...")
550
+ print(f"模态类型: {modality_type}")
551
+ print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
552
+ print(f"Text guidance scale: {text_guidance_scale}")
553
+ print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")
554
+
555
+ # 1. 模型初始化
556
+ replace_dit_model_in_manager()
557
+
558
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
559
+ model_manager.load_models([
560
+ "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
561
+ "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
562
+ "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
563
+ ])
564
+ pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
565
+
566
+ # 2. 添加传统camera编码器(兼容性)
567
+ dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
568
+ for block in pipe.dit.blocks:
569
+ block.cam_encoder = nn.Linear(13, dim)
570
+ block.projector = nn.Linear(dim, dim)
571
+ block.cam_encoder.weight.data.zero_()
572
+ block.cam_encoder.bias.data.zero_()
573
+ block.projector.weight = nn.Parameter(torch.eye(dim))
574
+ block.projector.bias = nn.Parameter(torch.zeros(dim))
575
+
576
+ # 3. 添加FramePack组件
577
+ add_framepack_components(pipe.dit)
578
+
579
+ # 4. 添加MoE组件
580
+ moe_config = {
581
+ "num_experts": moe_num_experts,
582
+ "top_k": moe_top_k,
583
+ "hidden_dim": moe_hidden_dim or dim * 2,
584
+ "sekai_input_dim": 13, # Sekai: 12维pose + 1维mask
585
+ "nuscenes_input_dim": 8, # NuScenes: 7维pose + 1维mask
586
+ "openx_input_dim": 13 # OpenX: 12维pose + 1维mask (类似sekai)
587
+ }
588
+ add_moe_components(pipe.dit, moe_config)
589
+
590
+ # 5. 加载训练好的权重
591
+ dit_state_dict = torch.load(dit_path, map_location="cpu")
592
+ pipe.dit.load_state_dict(dit_state_dict, strict=False) # 使用strict=False以兼容新增的MoE组件
593
+ pipe = pipe.to(device)
594
+ model_dtype = next(pipe.dit.parameters()).dtype
595
+
596
+ if hasattr(pipe.dit, 'clean_x_embedder'):
597
+ pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
598
+
599
+ pipe.scheduler.set_timesteps(50)
600
+
601
+ # 6. 加载初始条件
602
+ print("Loading initial condition frames...")
603
+ initial_latents, encoded_data = load_encoded_video_from_pth(
604
+ condition_pth_path,
605
+ start_frame=start_frame,
606
+ num_frames=initial_condition_frames
607
+ )
608
+
609
+ # 空间裁剪
610
+ target_height, target_width = 60, 104
611
+ C, T, H, W = initial_latents.shape
612
+
613
+ if H > target_height or W > target_width:
614
+ h_start = (H - target_height) // 2
615
+ w_start = (W - target_width) // 2
616
+ initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
617
+ H, W = target_height, target_width
618
+
619
+ history_latents = initial_latents.to(device, dtype=model_dtype)
620
+
621
+ print(f"初始history_latents shape: {history_latents.shape}")
622
+
623
+ # 7. 编码prompt - 支持CFG
624
+ if text_guidance_scale > 1.0:
625
+ prompt_emb_pos = pipe.encode_prompt(prompt)
626
+ prompt_emb_neg = pipe.encode_prompt("")
627
+ print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
628
+ else:
629
+ prompt_emb_pos = pipe.encode_prompt(prompt)
630
+ prompt_emb_neg = None
631
+ print("不使用Text CFG")
632
+
633
+ # 8. 加载场景信息(对于NuScenes)
634
+ scene_info = None
635
+ if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
636
+ with open(scene_info_path, 'r') as f:
637
+ scene_info = json.load(f)
638
+ print(f"加载NuScenes场景信息: {scene_info_path}")
639
+
640
+ # 9. 预生成完整的camera embedding序列
641
+ if modality_type == "sekai":
642
+ camera_embedding_full = generate_sekai_camera_embeddings_sliding(
643
+ encoded_data.get('cam_emb', None),
644
+ 0,
645
+ max_history_frames,
646
+ 0,
647
+ 0,
648
+ use_real_poses=use_real_poses
649
+ ).to(device, dtype=model_dtype)
650
+ elif modality_type == "nuscenes":
651
+ camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
652
+ scene_info,
653
+ 0,
654
+ max_history_frames,
655
+ 0
656
+ ).to(device, dtype=model_dtype)
657
+ elif modality_type == "openx":
658
+ camera_embedding_full = generate_openx_camera_embeddings_sliding(
659
+ encoded_data,
660
+ 0,
661
+ max_history_frames,
662
+ 0,
663
+ use_real_poses=use_real_poses
664
+ ).to(device, dtype=model_dtype)
665
+ else:
666
+ raise ValueError(f"不支持的模态类型: {modality_type}")
667
+
668
+ print(f"完整camera序列shape: {camera_embedding_full.shape}")
669
+
670
+ # 10. 为Camera CFG创建无条件的camera embedding
671
+ if use_camera_cfg:
672
+ camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
673
+ print(f"创建无条件camera embedding用于CFG")
674
+
675
+ # 11. 滑动窗口生成循环
676
+ total_generated = 0
677
+ all_generated_frames = []
678
+
679
+ while total_generated < total_frames_to_generate:
680
+ current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
681
+ print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
682
+ print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")
683
+
684
+ # FramePack数据准备 - MoE版本
685
+ framepack_data = prepare_framepack_sliding_window_with_camera_moe(
686
+ history_latents,
687
+ current_generation,
688
+ camera_embedding_full,
689
+ start_frame,
690
+ modality_type,
691
+ max_history_frames
692
+ )
693
+
694
+ # 准备输入
695
+ clean_latents = framepack_data['clean_latents'].unsqueeze(0)
696
+ clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
697
+ clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
698
+ camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
699
+
700
+ # 准备modality_inputs
701
+ modality_inputs = {modality_type: camera_embedding}
702
+
703
+ # 为CFG准备无条件camera embedding
704
+ if use_camera_cfg:
705
+ camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
706
+ modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}
707
+
708
+ # 索引处理
709
+ latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
710
+ clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
711
+ clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
712
+ clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
713
+
714
+ # 初始化要生成的latents
715
+ new_latents = torch.randn(
716
+ 1, C, current_generation, H, W,
717
+ device=device, dtype=model_dtype
718
+ )
719
+
720
+ extra_input = pipe.prepare_extra_input(new_latents)
721
+
722
+ print(f"Camera embedding shape: {camera_embedding.shape}")
723
+ print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
724
+
725
+ # 去噪循环 - 支持CFG
726
+ timesteps = pipe.scheduler.timesteps
727
+
728
+ for i, timestep in enumerate(timesteps):
729
+ if i % 10 == 0:
730
+ print(f" 去噪步骤 {i+1}/{len(timesteps)}")
731
+
732
+ timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
733
+
734
+ with torch.no_grad():
735
+ # CFG推理
736
+ if use_camera_cfg and camera_guidance_scale > 1.0:
737
+ # 条件预测(有camera)
738
+ noise_pred_cond, moe_loss = pipe.dit(
739
+ new_latents,
740
+ timestep=timestep_tensor,
741
+ cam_emb=camera_embedding,
742
+ modality_inputs=modality_inputs, # MoE模态输入
743
+ latent_indices=latent_indices,
744
+ clean_latents=clean_latents,
745
+ clean_latent_indices=clean_latent_indices,
746
+ clean_latents_2x=clean_latents_2x,
747
+ clean_latent_2x_indices=clean_latent_2x_indices,
748
+ clean_latents_4x=clean_latents_4x,
749
+ clean_latent_4x_indices=clean_latent_4x_indices,
750
+ **prompt_emb_pos,
751
+ **extra_input
752
+ )
753
+
754
+ # 无条件预测(无camera)
755
+ noise_pred_uncond, moe_loss = pipe.dit(
756
+ new_latents,
757
+ timestep=timestep_tensor,
758
+ cam_emb=camera_embedding_uncond_batch,
759
+ modality_inputs=modality_inputs_uncond, # MoE无条件模态输入
760
+ latent_indices=latent_indices,
761
+ clean_latents=clean_latents,
762
+ clean_latent_indices=clean_latent_indices,
763
+ clean_latents_2x=clean_latents_2x,
764
+ clean_latent_2x_indices=clean_latent_2x_indices,
765
+ clean_latents_4x=clean_latents_4x,
766
+ clean_latent_4x_indices=clean_latent_4x_indices,
767
+ **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
768
+ **extra_input
769
+ )
770
+
771
+ # Camera CFG
772
+ noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)
773
+
774
+ # 如果同时使用Text CFG
775
+ if text_guidance_scale > 1.0 and prompt_emb_neg:
776
+ noise_pred_text_uncond, moe_loss = pipe.dit(
777
+ new_latents,
778
+ timestep=timestep_tensor,
779
+ cam_emb=camera_embedding,
780
+ modality_inputs=modality_inputs,
781
+ latent_indices=latent_indices,
782
+ clean_latents=clean_latents,
783
+ clean_latent_indices=clean_latent_indices,
784
+ clean_latents_2x=clean_latents_2x,
785
+ clean_latent_2x_indices=clean_latent_2x_indices,
786
+ clean_latents_4x=clean_latents_4x,
787
+ clean_latent_4x_indices=clean_latent_4x_indices,
788
+ **prompt_emb_neg,
789
+ **extra_input
790
+ )
791
+
792
+ # 应用Text CFG到已经应用Camera CFG的结果
793
+ noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
794
+
795
+ elif text_guidance_scale > 1.0 and prompt_emb_neg:
796
+ # 只使用Text CFG
797
+ noise_pred_cond, moe_loss = pipe.dit(
798
+ new_latents,
799
+ timestep=timestep_tensor,
800
+ cam_emb=camera_embedding,
801
+ modality_inputs=modality_inputs,
802
+ latent_indices=latent_indices,
803
+ clean_latents=clean_latents,
804
+ clean_latent_indices=clean_latent_indices,
805
+ clean_latents_2x=clean_latents_2x,
806
+ clean_latent_2x_indices=clean_latent_2x_indices,
807
+ clean_latents_4x=clean_latents_4x,
808
+ clean_latent_4x_indices=clean_latent_4x_indices,
809
+ **prompt_emb_pos,
810
+ **extra_input
811
+ )
812
+
813
+ noise_pred_uncond, moe_loss = pipe.dit(
814
+ new_latents,
815
+ timestep=timestep_tensor,
816
+ cam_emb=camera_embedding,
817
+ modality_inputs=modality_inputs,
818
+ latent_indices=latent_indices,
819
+ clean_latents=clean_latents,
820
+ clean_latent_indices=clean_latent_indices,
821
+ clean_latents_2x=clean_latents_2x,
822
+ clean_latent_2x_indices=clean_latent_2x_indices,
823
+ clean_latents_4x=clean_latents_4x,
824
+ clean_latent_4x_indices=clean_latent_4x_indices,
825
+ **prompt_emb_neg,
826
+ **extra_input
827
+ )
828
+
829
+ noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)
830
+
831
+ else:
832
+ # 标准推理(无CFG)
833
+ noise_pred, moe_loss = pipe.dit(
834
+ new_latents,
835
+ timestep=timestep_tensor,
836
+ cam_emb=camera_embedding,
837
+ modality_inputs=modality_inputs, # MoE模态输入
838
+ latent_indices=latent_indices,
839
+ clean_latents=clean_latents,
840
+ clean_latent_indices=clean_latent_indices,
841
+ clean_latents_2x=clean_latents_2x,
842
+ clean_latent_2x_indices=clean_latent_2x_indices,
843
+ clean_latents_4x=clean_latents_4x,
844
+ clean_latent_4x_indices=clean_latent_4x_indices,
845
+ **prompt_emb_pos,
846
+ **extra_input
847
+ )
848
+
849
+ new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
850
+
851
+ # 更新历史
852
+ new_latents_squeezed = new_latents.squeeze(0)
853
+ history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
854
+
855
+ # 维护滑动窗口
856
+ if history_latents.shape[1] > max_history_frames:
857
+ first_frame = history_latents[:, 0:1, :, :]
858
+ recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
859
+ history_latents = torch.cat([first_frame, recent_frames], dim=1)
860
+ print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")
861
+
862
+ print(f"更新后history_latents shape: {history_latents.shape}")
863
+
864
+ all_generated_frames.append(new_latents_squeezed)
865
+ total_generated += current_generation
866
+
867
+ print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")
868
+
869
+ # 12. 解码和保存
870
+ print("\n🔧 解码生成的视频...")
871
+
872
+ all_generated = torch.cat(all_generated_frames, dim=1)
873
+ final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
874
+
875
+ print(f"最终视频shape: {final_video.shape}")
876
+
877
+ decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
878
+
879
+ print(f"Saving video to {output_path}")
880
+
881
+ video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
882
+ video_np = (video_np * 0.5 + 0.5).clip(0, 1)
883
+ video_np = (video_np * 255).astype(np.uint8)
884
+
885
+ with imageio.get_writer(output_path, fps=20) as writer:
886
+ for frame in video_np:
887
+ writer.append_data(frame)
888
+
889
+ print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
890
+ print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
891
+ print(f"使用模态: {modality_type}")
892
+
893
+
894
+ def main():
895
+ parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")
896
+
897
+ # 基��参数
898
+ parser.add_argument("--condition_pth", type=str,
899
+ default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
900
+ #default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
901
+ #default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
902
+ #default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
903
+ parser.add_argument("--start_frame", type=int, default=0)
904
+ parser.add_argument("--initial_condition_frames", type=int, default=16)
905
+ parser.add_argument("--frames_per_generation", type=int, default=8)
906
+ parser.add_argument("--total_frames_to_generate", type=int, default=40)
907
+ parser.add_argument("--max_history_frames", type=int, default=100)
908
+ parser.add_argument("--use_real_poses", action="store_true", default=False)
909
+ parser.add_argument("--dit_path", type=str,
910
+ default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_test/step1000_moe.ckpt")
911
+ parser.add_argument("--output_path", type=str,
912
+ default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
913
+ parser.add_argument("--prompt", type=str,
914
+ default="A drone flying scene in a game world")
915
+ parser.add_argument("--device", type=str, default="cuda")
916
+
917
+ # 模态类型参数
918
+ parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
919
+ help="模态类型:sekai 或 nuscenes 或 openx")
920
+ parser.add_argument("--scene_info_path", type=str, default=None,
921
+ help="NuScenes场景信息文件路径(仅用于nuscenes模态)")
922
+
923
+ # CFG参数
924
+ parser.add_argument("--use_camera_cfg", default=True,
925
+ help="使用Camera CFG")
926
+ parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
927
+ help="Camera guidance scale for CFG")
928
+ parser.add_argument("--text_guidance_scale", type=float, default=1.0,
929
+ help="Text guidance scale for CFG")
930
+
931
+ # MoE参数
932
+ parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量")
933
+ parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
934
+ parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
935
+
936
+ args = parser.parse_args()
937
+
938
+ print(f"🔧 MoE FramePack CFG生成设置:")
939
+ print(f"模态类型: {args.modality_type}")
940
+ print(f"Camera CFG: {args.use_camera_cfg}")
941
+ if args.use_camera_cfg:
942
+ print(f"Camera guidance scale: {args.camera_guidance_scale}")
943
+ print(f"Text guidance scale: {args.text_guidance_scale}")
944
+ print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
945
+
946
+ # 验证NuScenes参数
947
+ if args.modality_type == "nuscenes" and not args.scene_info_path:
948
+ print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")
949
+
950
+ inference_moe_framepack_sliding_window(
951
+ condition_pth_path=args.condition_pth,
952
+ dit_path=args.dit_path,
953
+ output_path=args.output_path,
954
+ start_frame=args.start_frame,
955
+ initial_condition_frames=args.initial_condition_frames,
956
+ frames_per_generation=args.frames_per_generation,
957
+ total_frames_to_generate=args.total_frames_to_generate,
958
+ max_history_frames=args.max_history_frames,
959
+ device=args.device,
960
+ prompt=args.prompt,
961
+ modality_type=args.modality_type,
962
+ use_real_poses=args.use_real_poses,
963
+ scene_info_path=args.scene_info_path,
964
+ # CFG参数
965
+ use_camera_cfg=args.use_camera_cfg,
966
+ camera_guidance_scale=args.camera_guidance_scale,
967
+ text_guidance_scale=args.text_guidance_scale,
968
+ # MoE参数
969
+ moe_num_experts=args.moe_num_experts,
970
+ moe_top_k=args.moe_top_k,
971
+ moe_hidden_dim=args.moe_hidden_dim
972
+ )
973
+
974
+
975
+ if __name__ == "__main__":
976
+ main()
scripts/infer_nus.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import imageio
6
+ import json
7
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
8
+ import argparse
9
+ from torchvision.transforms import v2
10
+ from einops import rearrange
11
+ import torch.nn as nn
12
+ from pose_classifier import PoseClassifier
13
+
14
+
def load_video_frames(video_path, num_frames=20, height=900, width=1600):
    """Load up to *num_frames* frames from a video and preprocess them.

    Each frame is resized to a fixed 480x832, converted to a tensor and
    normalized to [-1, 1].

    NOTE(review): `height`/`width` are currently unused — the resize target
    is hard-coded; confirm whether they should drive the resize.

    Args:
        video_path: path to a video readable by imageio.
        num_frames: maximum number of frames to load.
        height, width: unused (see note above).

    Returns:
        Float tensor of shape [C, T, H, W] in [-1, 1], or None when the
        video yields no frames.
    """
    frame_process = v2.Compose([
        v2.ToTensor(),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    def crop_and_resize(image):
        # Fixed target resolution, bilinear resampling.
        return v2.functional.resize(
            image,
            (480, 832),
            interpolation=v2.InterpolationMode.BILINEAR
        )

    reader = imageio.get_reader(video_path)
    frames = []
    try:
        for i, frame_data in enumerate(reader):
            if i >= num_frames:
                break
            frame = Image.fromarray(frame_data)
            frame = crop_and_resize(frame)
            frames.append(frame_process(frame))
    finally:
        # Release the decoder even if decoding fails partway through
        # (the original leaked the reader on exceptions).
        reader.close()

    if not frames:
        return None

    stacked = torch.stack(frames, dim=0)
    return rearrange(stacked, "T C H W -> C T H W")
+
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Compute the relative rotation quaternion q_rel = q_ref^-1 * q_current.

    Both inputs are (w, x, y, z) quaternions; for unit quaternions the
    inverse equals the conjugate, which is what is used here.

    Returns:
        A float32 tensor of shape [4] in (w, x, y, z) order.
    """
    current = torch.tensor(current_rotation, dtype=torch.float32)
    reference = torch.tensor(reference_rotation, dtype=torch.float32)

    # Conjugate of the reference quaternion (its inverse when unit-norm).
    a_w, a_x, a_y, a_z = reference[0], -reference[1], -reference[2], -reference[3]
    b_w, b_x, b_y, b_z = current[0], current[1], current[2], current[3]

    # Hamilton product: q_ref^-1 * q_current.
    return torch.tensor([
        a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z,
        a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y,
        a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x,
        a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w,
    ])
+
def generate_direction_poses(direction="left", target_frames=10, condition_frames=20):
    """
    Generate pose-class embeddings for a requested motion direction, covering
    both condition and target frames.

    NOTE(review): the default ``direction="left"`` is not one of the accepted
    values and would hit the ValueError branch below — confirm whether the
    default should be "left_turn".

    Fix vs. original: the two debug prints were garbled
    (``f"conditon{...}"`` / ``f"target{...}"``) and are now readable.

    Args:
        direction: 'forward', 'backward', 'left_turn', 'right_turn'
        target_frames: number of target frames
        condition_frames: number of condition frames

    Returns:
        Class embeddings produced by
        create_enhanced_class_embedding_for_inference for the full sequence.

    Raises:
        ValueError: if *direction* is not one of the four supported values.
    """
    classifier = PoseClassifier()

    total_frames = condition_frames + target_frames
    print(f"condition frames: {condition_frames}")
    print(f"target frames: {target_frames}")
    poses = []

    # Condition frames: relatively stable forward motion.
    for i in range(condition_frames):
        t = i / max(1, condition_frames - 1)  # 0 to 1

        translation = [-t * 0.5, 0.0, 0.0]  # slow forward motion
        rotation = [1.0, 0.0, 0.0, 0.0]     # no rotation
        frame_type = 0.0                    # condition

        pose_vec = translation + rotation + [frame_type]  # 8D vector
        poses.append(pose_vec)

    # Target frames: pose depends on the requested direction.
    for i in range(target_frames):
        t = i / max(1, target_frames - 1)  # 0 to 1

        if direction == "forward":
            # Forward: move along negative x, no rotation.
            translation = [-(condition_frames * 0.5 + t * 2.0), 0.0, 0.0]
            rotation = [1.0, 0.0, 0.0, 0.0]  # unit quaternion

        elif direction == "backward":
            # Backward: move along positive x, no rotation.
            translation = [-(condition_frames * 0.5) + t * 2.0, 0.0, 0.0]
            rotation = [1.0, 0.0, 0.0, 0.0]

        elif direction == "left_turn":
            # Left turn: forward motion plus positive yaw about z.
            translation = [-(condition_frames * 0.5 + t * 1.5), t * 0.5, 0.0]
            yaw = t * 0.3  # left-turn angle (radians)
            rotation = [
                np.cos(yaw/2),  # w
                0.0,            # x
                0.0,            # y
                np.sin(yaw/2)   # z (positive = left)
            ]

        elif direction == "right_turn":
            # Right turn: forward motion plus negative yaw about z.
            translation = [-(condition_frames * 0.5 + t * 1.5), -t * 0.5, 0.0]
            yaw = -t * 0.3  # right-turn angle (radians)
            rotation = [
                np.cos(abs(yaw)/2),  # w
                0.0,                 # x
                0.0,                 # y
                np.sin(yaw/2)        # z (negative = right)
            ]
        else:
            raise ValueError(f"Unknown direction: {direction}")

        frame_type = 1.0  # target
        pose_vec = translation + rotation + [frame_type]  # 8D vector
        poses.append(pose_vec)

    pose_sequence = torch.tensor(poses, dtype=torch.float32)

    # Classify only the target part (first 7 dims, frame-type stripped).
    target_pose_sequence = pose_sequence[condition_frames:, :7]

    # Condition frames are all labeled as class 0 (forward).
    condition_classes = torch.full((condition_frames,), 0, dtype=torch.long)
    target_classes = classifier.classify_pose_sequence(target_pose_sequence)
    full_classes = torch.cat([condition_classes, target_classes], dim=0)

    # Build the enhanced embedding from labels plus raw pose features.
    class_embeddings = create_enhanced_class_embedding_for_inference(
        full_classes, pose_sequence, embed_dim=512
    )

    print(f"Generated {direction} poses:")
    print(f" Total frames: {total_frames} (condition: {condition_frames}, target: {target_frames})")
    analysis = classifier.analyze_pose_sequence(target_pose_sequence)
    print(f" Target class distribution: {analysis['class_distribution']}")
    print(f" Target motion segments: {len(analysis['motion_segments'])}")

    return class_embeddings
+
167
def create_enhanced_class_embedding_for_inference(class_labels: torch.Tensor, pose_sequence: torch.Tensor, embed_dim: int = 512) -> torch.Tensor:
    """Create enhanced per-frame class embeddings for inference.

    Combines one-hot direction classes (mapped to signed direction vectors),
    the condition/target frame-type flag, and the raw pose geometry
    (translation + quaternion) into a 13-d feature, then expands it to
    ``embed_dim`` with a fixed random projection whose top-left block is the
    identity, so the first 13 output dims reproduce the features exactly.

    Args:
        class_labels: [num_frames] long tensor with values in {0..3}
            (forward, backward, left_turn, right_turn).
        pose_sequence: [num_frames, 8] tensor — 3 translation dims,
            4 quaternion dims, 1 frame-type flag (0 = condition, 1 = target).
        embed_dim: width of the output embedding.

    Returns:
        [num_frames, embed_dim] float32 tensor.
    """
    num_classes = 4
    num_frames = len(class_labels)

    # Signed unit vectors, one per motion class.
    direction_vectors = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],   # forward
        [-1.0, 0.0, 0.0, 0.0],  # backward
        [0.0, 1.0, 0.0, 0.0],   # left_turn
        [0.0, -1.0, 0.0, 0.0],  # right_turn
    ], dtype=torch.float32)

    # One-hot encode the class labels.
    one_hot = torch.zeros(num_frames, num_classes)
    one_hot.scatter_(1, class_labels.unsqueeze(1), 1)

    # Base embedding: each frame's class mapped to its direction vector.
    base_embeddings = one_hot @ direction_vectors  # [num_frames, 4]

    # Frame-type flags (last pose dim): 0 = condition, 1 = target.
    frame_types = pose_sequence[:, -1]
    frame_type_embeddings = torch.zeros(num_frames, 2)
    frame_type_embeddings[:, 0] = (frame_types == 0).float()  # condition
    frame_type_embeddings[:, 1] = (frame_types == 1).float()  # target

    # Raw pose geometry.
    translations = pose_sequence[:, :3]  # [num_frames, 3]
    rotations = pose_sequence[:, 3:7]    # [num_frames, 4]

    # Combine all features into a single 13-d vector per frame.
    combined_features = torch.cat([
        base_embeddings,        # [num_frames, 4]
        frame_type_embeddings,  # [num_frames, 2]
        translations,           # [num_frames, 3]
        rotations,              # [num_frames, 4]
    ], dim=1)  # [num_frames, 13]

    # Expand to the target dimensionality.
    if embed_dim > 13:
        # Fix: use a seeded generator so the expansion matrix — and therefore
        # the conditioning embedding fed to the model — is identical across
        # calls and across runs. The previous unseeded torch.randn produced a
        # different projection every call, making inference non-reproducible.
        gen = torch.Generator().manual_seed(0)
        expand_matrix = torch.randn(13, embed_dim, generator=gen) * 0.1
        expand_matrix[:13, :13] = torch.eye(13)  # pass raw features through unchanged
        embeddings = combined_features @ expand_matrix
    else:
        embeddings = combined_features[:, :embed_dim]

    return embeddings
214
+
215
def generate_poses_from_file(poses_path, target_frames=10):
    """Build a 512-d class-embedding sequence from a poses.json file.

    Reads ``target_relative_poses`` from the JSON, resamples the entries onto
    ``target_frames`` output frames, classifies the resulting 7-d pose
    vectors, and returns the classifier's class embedding. Falls back to a
    synthetic forward motion when the file contains no poses.
    """
    classifier = PoseClassifier()

    with open(poses_path, 'r') as f:
        poses_data = json.load(f)

    target_relative_poses = poses_data['target_relative_poses']

    # Fall back to synthetic forward motion when the file holds no poses.
    if not target_relative_poses:
        print("No poses found in file, using forward direction")
        return generate_direction_poses("forward", target_frames)

    num_available = len(target_relative_poses)
    pose_vecs = []
    for frame_idx in range(target_frames):
        # Map each output frame onto one of the available pose entries.
        if num_available == 1:
            entry = target_relative_poses[0]
        else:
            src_idx = min(frame_idx * num_available // target_frames,
                          num_available - 1)
            entry = target_relative_poses[src_idx]

        # Relative translation plus the rotation of the current frame
        # expressed relative to the reference frame.
        trans = torch.tensor(entry['relative_translation'], dtype=torch.float32)
        cur_rot = torch.tensor(entry['current_rotation'], dtype=torch.float32)
        ref_rot = torch.tensor(entry['reference_rotation'], dtype=torch.float32)
        rel_rot = calculate_relative_rotation(cur_rot, ref_rot)

        # 7-d pose vector: 3 translation + 4 rotation.
        pose_vecs.append(torch.cat([trans, rel_rot], dim=0))

    pose_sequence = torch.stack(pose_vecs, dim=0)

    # Classify the sequence and turn the labels into class embeddings.
    class_labels = classifier.classify_pose_sequence(pose_sequence)
    class_embeddings = classifier.create_class_embedding(class_labels, embed_dim=512)

    print(f"Generated poses from file:")
    analysis = classifier.analyze_pose_sequence(pose_sequence)
    print(f" Class distribution: {analysis['class_distribution']}")
    print(f" Motion segments: {len(analysis['motion_segments'])}")

    return class_embeddings
264
+
265
def inference_nuscenes_video(
    condition_video_path,
    dit_path,
    text_encoder_path,
    vae_path,
    output_path="nus/infer_results/output_nuscenes.mp4",
    condition_frames=20,
    target_frames=3,
    height=900,
    width=1600,
    device="cuda",
    prompt="A car driving scene captured by front camera",
    poses_path=None,
    direction="forward"
):
    """Direction-class-controlled inference: extend a condition video with
    generated target frames following the requested camera direction.

    NOTE(review): text_encoder_path, vae_path, and poses_path are accepted
    but never read in this body — model files are loaded from hard-coded
    paths below, and poses always come from generate_direction_poses.

    Args:
        condition_video_path: video whose first ``condition_frames`` frames
            condition the generation.
        dit_path: checkpoint with the trained DiT weights (must include the
            cam_encoder/projector params added below; loaded strict=True).
        output_path: where the decoded mp4 is written.
        condition_frames: raw (pre-VAE) condition frame count; divided by 4
            below to get the latent frame count.
        target_frames: number of *latent* frames to generate.
        direction: one of forward / backward / left_turn / right_turn.
    """
    os.makedirs(os.path.dirname(output_path),exist_ok=True)

    print(f"Setting up models for {direction} movement...")

    # 1. Load base Wan2.1 models (DiT + T5 text encoder + VAE) on CPU first.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add camera components to each DiT block. cam_encoder maps the 512-d
    # class embedding into the block dim; both are initialized to identity /
    # zero so real values come from the checkpoint loaded below.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(512, dim)  # keep 512-d embedding input
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Load trained DiT weights (overwrites the zero-initialized camera layers).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    pipe.scheduler.set_timesteps(50)

    print("Loading condition video...")

    # Load and resize the condition frames.
    condition_video = load_video_frames(
        condition_video_path,
        num_frames=condition_frames,
        height=height,
        width=width
    )

    if condition_video is None:
        raise ValueError(f"Failed to load condition video from {condition_video_path}")

    condition_video = condition_video.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

    print("Processing poses...")

    # Build the pose embedding covering condition AND target frames.
    # condition_frames/4: VAE compresses time by 4x, so pose frames are
    # counted in latent frames.
    print(f"Generating {direction} movement poses...")
    camera_embedding = generate_direction_poses(
        direction=direction,
        target_frames=target_frames,
        condition_frames=int(condition_frames/4)  # latent-frame count after 4x compression
    )

    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16)

    print(f"Camera embedding shape: {camera_embedding.shape}")
    print(f"Generated poses for direction: {direction}")

    print("Encoding inputs...")

    # Encode text prompt.
    prompt_emb = pipe.encode_prompt(prompt)

    # Encode condition video into latents (tiled to bound memory use).
    condition_latents = pipe.encode_video(condition_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))[0]

    print("Generating video...")

    # Target latent geometry; spatial size is capped at 60x104 latents.
    batch_size = 1
    channels = condition_latents.shape[0]
    latent_height = condition_latents.shape[2]
    latent_width = condition_latents.shape[3]
    target_height, target_width = 60, 104  # adjust to your setup if needed

    if latent_height > target_height or latent_width > target_width:
        # Center-crop oversized latents.
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2
        condition_latents = condition_latents[:, :,
                                              h_start:h_start+target_height,
                                              w_start:w_start+target_width]
        latent_height = target_height
        latent_width = target_width
    condition_latents = condition_latents.to(device, dtype=pipe.torch_dtype)
    condition_latents = condition_latents.unsqueeze(0)
    condition_latents = condition_latents + 0.05 * torch.randn_like(condition_latents) # add a little noise for diversity

    # Initialize target latents with pure noise.
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=pipe.torch_dtype
    )
    print(target_latents.shape)
    print(camera_embedding.shape)
    # Concatenate condition and target latents along the time axis.
    combined_latents = torch.cat([condition_latents, target_latents], dim=2)
    print(combined_latents.shape)

    # Prepare model-specific extra inputs (e.g. positional info).
    extra_input = pipe.prepare_extra_input(combined_latents)

    # Denoising loop: only the target slice is updated each step; the
    # condition slice stays fixed as context.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        # Prepare timestep tensor.
        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

        # Predict noise over the full (condition + target) sequence.
        with torch.no_grad():
            noise_pred = pipe.dit(
                combined_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                **prompt_emb,
                **extra_input
            )

        # Update only the target part (latent frames after condition_frames/4).
        # NOTE(review): assumes scheduler.step returns the updated latents —
        # confirm against the scheduler implementation.
        target_noise_pred = noise_pred[:, :, int(condition_frames/4):, :, :]
        target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents)

        # Write the updated target slice back into the combined tensor.
        combined_latents[:, :, int(condition_frames/4):, :, :] = target_latents

    print("Decoding video...")

    # Decode the final (condition + generated) latents back to pixels.
    final_video = torch.cat([condition_latents, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # Save video.
    print(f"Saving video to {output_path}")

    # Convert to float32 HWC uint8 frames and denormalize from [-1, 1].
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)  # denormalize
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"Video generation completed! Saved to {output_path}")
431
+
432
def main():
    """CLI entry point for direction-controlled NuScenes video generation."""
    parser = argparse.ArgumentParser(description="NuScenes Video Generation Inference with Direction Control")
    parser.add_argument("--condition_video", type=str, default="/home/zhuyixuan05/ReCamMaster/nus/videos/4032/right.mp4",
                        help="Path to condition video")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement")
    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/nus_dynamic/step15000_dynamic.ckpt",
                        help="Path to trained DiT checkpoint")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
                        help="Path to text encoder")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
                        help="Path to VAE")
    # Fix: default to None so the direction-suffixed auto-naming branch below
    # is reachable. The previous hard-coded default path made
    # `if args.output_path is None` dead code.
    parser.add_argument("--output_path", type=str, default=None,
                        help="Output video path (default: auto-named from the input video and direction)")
    parser.add_argument("--poses_path", type=str, default=None,
                        help="Path to poses.json file (optional, will use direction if not provided)")
    parser.add_argument("--prompt", type=str,
                        default="A car driving scene captured by front camera",
                        help="Text prompt for generation")
    # Raw (pre-VAE) frame count; inference divides this by 4 for latent frames.
    parser.add_argument("--condition_frames", type=int, default=40,
                        help="Number of condition frames")
    # Counted in latent frames (raw frames / 4).
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate")
    parser.add_argument("--height", type=int, default=900,
                        help="Video height")
    parser.add_argument("--width", type=int, default=1600,
                        help="Video width")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    condition_video_path = args.condition_video
    input_filename = os.path.basename(condition_video_path)
    output_dir = "nus/infer_results"
    os.makedirs(output_dir, exist_ok=True)

    # Auto-name the output with the requested direction when no explicit
    # path was given, e.g. right.mp4 + left_turn -> right_left_turn.mp4.
    if args.output_path is None:
        name_parts = os.path.splitext(input_filename)
        output_filename = f"{name_parts[0]}_{args.direction}{name_parts[1]}"
        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = args.output_path

    print(f"Output video will be saved to: {output_path}")
    inference_nuscenes_video(
        condition_video_path=args.condition_video,
        dit_path=args.dit_path,
        text_encoder_path=args.text_encoder_path,
        vae_path=args.vae_path,
        output_path=output_path,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        height=args.height,
        width=args.width,
        device=args.device,
        prompt=args.prompt,
        poses_path=args.poses_path,
        direction=args.direction
    )
498
+
499
+ if __name__ == "__main__":
500
+ main()
scripts/infer_openx.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
3
+ from torchvision.transforms import v2
4
+ from einops import rearrange
5
+ import os
6
+ import torch
7
+ import torch.nn as nn
8
+ import argparse
9
+ import numpy as np
10
+ import imageio
11
+ import copy
12
+ import random
13
+
14
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice out a clip.

    Args:
        pth_path: path to a torch-saved dict containing a 'latents' tensor
            of shape [C, T, H, W].
        start_frame: first temporal index of the clip.
        num_frames: number of latent frames to extract.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice
        and the full dict loaded from disk.

    Raises:
        ValueError: if the requested window runs past the available frames.
    """
    print(f"Loading encoded video from {pth_path}")

    # weights_only=False: the checkpoint stores arbitrary python objects;
    # only load files from trusted sources.
    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
31
+
32
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A, i.e. B @ A^-1.

    Args:
        pose_a: 4x4 extrinsic matrix of camera A (numpy array, tensor, or
            array-like).
        pose_b: 4x4 extrinsic matrix of camera B.
        use_torch: if True, compute with torch (converting numpy inputs);
            otherwise compute with numpy.

    Returns:
        4x4 relative pose as a torch.Tensor when use_torch, else np.ndarray.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(b, torch.inverse(a))

    a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(b, np.linalg.inv(a))
55
+
56
def replace_dit_model_in_manager():
    """Swap the registered DiT class for WanModelFuture before models load.

    Mutates diffsynth's global ``model_loader_configs`` in place: any config
    entry whose model-name list contains 'wan_video_dit' gets that entry's
    class replaced by WanModelFuture (names are kept unchanged).
    """
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        # Only touch configs that register the wan_video_dit model.
        if 'wan_video_dit' not in names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(names, classes):
            if name == 'wan_video_dit':
                patched_names.append(name)  # name stays the same
                patched_classes.append(WanModelFuture)  # class is swapped
                print(f"✅ 替换了模型类: {name} -> WanModelFuture")
            else:
                patched_names.append(name)
                patched_classes.append(cls)

        # Write the patched tuple back into the global config list.
        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
82
+
83
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to a DiT model.

    Adds a ``clean_x_embedder`` module with three Conv3d projections at
    1x/2x/4x downsampling (mirroring hunyuan_video_packed.py), cast to the
    model's parameter dtype. No-op if the model already has one.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            # Mirrors the patchify design in hunyuan_video_packed.py.
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch on downsampling scale.
            projections = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in projections:
                raise ValueError(f"Unsupported scale: {scale}")
            return projections[scale](x)

    embedder = CleanXEmbedder(inner_dim)
    # Match the embedder's dtype to the rest of the model.
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
110
+
111
def generate_openx_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Generate camera embeddings for the OpenX dataset (sliding-window version).

    Produces a [max_needed_frames, 13] tensor per frame: a flattened 3x4
    relative pose (12 values) plus a 1-value condition mask, cast to bfloat16.
    Uses real extrinsics from ``cam_data['extrinsic']`` when available,
    otherwise synthesizes gentle robot-manipulation-style motion.

    NOTE(review): ``total_generated`` is accepted but never used in this body.
    """
    # One camera (latent) frame per 4 raw video frames.
    time_compression_ratio = 4

    # Camera frames the FramePack layout requires:
    # 1 start + 16 coarse + 2 mid + 1 fine + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实OpenX camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Generate a camera sequence long enough for both the sliding window
        # and the FramePack layout (with a floor of 30 frames).
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX-specific: sample extrinsics every 4 raw frames and use
            # the frame-to-frame relative pose.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_cam = compute_relative_pose(cam_prev, cam_next)
                # Keep only the top 3x4 of the 4x4 relative pose.
                relative_poses.append(torch.as_tensor(relative_cam[:3, :]))
            else:
                # Past the end of the recorded extrinsics: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        # [T, 3, 4] -> [T, 12]
        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask: 1.0 for frames in [start_frame, start+history).
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用OpenX合成camera数据")

        # Same length computation as the real-pose branch.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX robot-manipulation style: stable, small-amplitude motion
            # emulating fine robot-arm movement.
            forward_speed = 0.001  # per-frame forward step (tiny: fine manipulation)
            lateral_motion = 0.0005 * np.sin(i * 0.05)  # slight left/right drift
            vertical_motion = 0.0003 * np.cos(i * 0.1)  # slight up/down drift

            # Small viewpoint adjustments.
            yaw_change = 0.01 * np.sin(i * 0.03)    # slight yaw
            pitch_change = 0.008 * np.cos(i * 0.04)  # slight pitch

            pose = np.eye(4, dtype=np.float32)

            # Rotation (small angles about the Y and X axes).
            cos_yaw = np.cos(yaw_change)
            sin_yaw = np.sin(yaw_change)
            cos_pitch = np.cos(pitch_change)
            sin_pitch = np.sin(pitch_change)

            # Combined rotation (pitch then yaw).
            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[1, 1] = cos_pitch
            pose[1, 2] = -sin_pitch
            pose[2, 0] = -sin_yaw
            pose[2, 1] = sin_pitch
            pose[2, 2] = cos_yaw * cos_pitch

            # Translation (small-amplitude fine-manipulation motion).
            pose[0, 3] = lateral_motion    # X (left/right)
            pose[1, 3] = vertical_motion   # Y (up/down)
            pose[2, 3] = -forward_speed    # Z (negative = forward)

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        # [T, 3, 4] -> [T, 12]
        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask, same convention as the real-pose branch.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
221
+
222
def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """Assemble FramePack sliding-window inputs (OpenX variant).

    Builds the fixed FramePack index layout
    [start(1) | 4x(16) | 2x(2) | 1x(1) | new(target_frames_to_generate)],
    right-aligns the most recent history into a 19-frame clean-latent buffer,
    and slices/pads ``camera_embedding_full`` to the window, rewriting its
    last column as the condition(1)/target(0) mask.

    Args:
        history_latents: [C, T, H, W] latents generated so far (T may be 0).
        target_frames_to_generate: new latent frames for this step.
        camera_embedding_full: [T_cam, D] camera sequence; last column is
            the mask channel.
        start_frame: kept for interface compatibility (unused here).
        max_history_frames: kept for interface compatibility (unused here).

    Returns:
        dict with index tensors, clean-latent tensors, the per-window camera
        embedding, and current/next history lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout; its total length also fixes the camera window size.
    window_len = 1 + 16 + 2 + 1 + target_frames_to_generate
    sections = torch.arange(0, window_len).split([1, 16, 2, 1, target_frames_to_generate], dim=0)
    start_idx, idx_4x, idx_2x, idx_1x, latent_indices = sections
    clean_latent_indices = torch.cat([start_idx, idx_1x], dim=0)

    # Zero-pad the camera sequence if it is shorter than the window.
    cam = camera_embedding_full
    if cam.shape[0] < window_len:
        pad = torch.zeros(window_len - cam.shape[0], cam.shape[1],
                          dtype=cam.dtype, device=cam.device)
        cam = torch.cat([cam, pad], dim=0)

    # Take the window slice and rebuild its mask column.
    combined_camera = cam[:window_len, :].clone()
    combined_camera[:, -1] = 0.0  # default: everything is target

    # Mark the camera rows backing valid clean latents as condition.
    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        combined_camera[start_pos:19, -1] = 1.0

    print(f"🔧 OpenX Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")

    # Right-align the most recent history into a 19-frame clean buffer.
    clean_buf = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        available_frames = min(T, 19)
        clean_buf[:, 19 - available_frames:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_buf[:, 0:16, :, :]
    clean_latents_2x = clean_buf[:, 16:18, :, :]
    clean_latents_1x = clean_buf[:, 18:19, :, :]

    # The very first generated frame anchors the sequence; zeros when empty.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    return {
        'latent_indices': latent_indices,
        'clean_latents': torch.cat([start_latent, clean_latents_1x], dim=1),
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
289
+
290
+ def inference_openx_framepack_sliding_window(
291
+ condition_pth_path,
292
+ dit_path,
293
+ output_path="openx_results/output_openx_framepack_sliding.mp4",
294
+ start_frame=0,
295
+ initial_condition_frames=8,
296
+ frames_per_generation=4,
297
+ total_frames_to_generate=32,
298
+ max_history_frames=49,
299
+ device="cuda",
300
+ prompt="A video of robotic manipulation task with camera movement",
301
+ use_real_poses=True,
302
+ # CFG参数
303
+ use_camera_cfg=True,
304
+ camera_guidance_scale=2.0,
305
+ text_guidance_scale=1.0
306
+ ):
307
+ """
308
+ OpenX FramePack滑动窗口视频生成
309
+ """
310
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
311
+ print(f"🔧 OpenX FramePack滑动窗口生成开始...")
312
+ print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
313
+ print(f"Text guidance scale: {text_guidance_scale}")
314
+
315
+ # 1. 模型初始化
316
+ replace_dit_model_in_manager()
317
+
318
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
319
+ model_manager.load_models([
320
+ "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
321
+ "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
322
+ "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
323
+ ])
324
+ pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
325
+
326
+ # 2. 添加camera编码器
327
+ dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
328
+ for block in pipe.dit.blocks:
329
+ block.cam_encoder = nn.Linear(13, dim)
330
+ block.projector = nn.Linear(dim, dim)
331
+ block.cam_encoder.weight.data.zero_()
332
+ block.cam_encoder.bias.data.zero_()
333
+ block.projector.weight = nn.Parameter(torch.eye(dim))
334
+ block.projector.bias = nn.Parameter(torch.zeros(dim))
335
+
336
+ # 3. 添加FramePack组件
337
+ add_framepack_components(pipe.dit)
338
+
339
+ # 4. 加载训练好的权重
340
+ dit_state_dict = torch.load(dit_path, map_location="cpu")
341
+ pipe.dit.load_state_dict(dit_state_dict, strict=True)
342
+ pipe = pipe.to(device)
343
+ model_dtype = next(pipe.dit.parameters()).dtype
344
+
345
+ if hasattr(pipe.dit, 'clean_x_embedder'):
346
+ pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
347
+
348
+ pipe.scheduler.set_timesteps(50)
349
+
350
+ # 5. 加载初始条件
351
+ print("Loading initial condition frames...")
352
+ initial_latents, encoded_data = load_encoded_video_from_pth(
353
+ condition_pth_path,
354
+ start_frame=start_frame,
355
+ num_frames=initial_condition_frames
356
+ )
357
+
358
+ # 空间裁剪(适配OpenX数据尺寸)
359
+ target_height, target_width = 60, 104
360
+ C, T, H, W = initial_latents.shape
361
+
362
+ if H > target_height or W > target_width:
363
+ h_start = (H - target_height) // 2
364
+ w_start = (W - target_width) // 2
365
+ initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
366
+ H, W = target_height, target_width
367
+
368
+ history_latents = initial_latents.to(device, dtype=model_dtype)
369
+
370
+ print(f"初始history_latents shape: {history_latents.shape}")
371
+
372
+ # 6. 编码prompt - 支持CFG
373
+ if text_guidance_scale > 1.0:
374
+ prompt_emb_pos = pipe.encode_prompt(prompt)
375
+ prompt_emb_neg = pipe.encode_prompt("")
376
+ print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
377
+ else:
378
+ prompt_emb_pos = pipe.encode_prompt(prompt)
379
+ prompt_emb_neg = None
380
+ print("不使用Text CFG")
381
+
382
+ # 7. 预生成完整的camera embedding序列
383
+ camera_embedding_full = generate_openx_camera_embeddings_sliding(
384
+ encoded_data.get('cam_emb', None),
385
+ 0,
386
+ max_history_frames,
387
+ 0,
388
+ 0,
389
+ use_real_poses=use_real_poses
390
+ ).to(device, dtype=model_dtype)
391
+
392
+ print(f"完整camera序列shape: {camera_embedding_full.shape}")
393
+
394
+ # 8. 为Camera CFG创建无条件的camera embedding
395
+ if use_camera_cfg:
396
+ camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
397
+ print(f"创建无条件camera embedding用于CFG")
398
+
399
+ # 9. 滑动窗口生成循环
400
+ total_generated = 0
401
+ all_generated_frames = []
402
+
403
+ while total_generated < total_frames_to_generate:
404
+ current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
405
+ print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
406
+ print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")
407
+
408
+ # FramePack数据准备 - OpenX版本
409
+ framepack_data = prepare_framepack_sliding_window_with_camera(
410
+ history_latents,
411
+ current_generation,
412
+ camera_embedding_full,
413
+ start_frame,
414
+ max_history_frames
415
+ )
416
+
417
+ # 准备输入
418
+ clean_latents = framepack_data['clean_latents'].unsqueeze(0)
419
+ clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
420
+ clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
421
+ camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
422
+
423
+ # 为CFG准备无条件camera embedding
424
+ if use_camera_cfg:
425
+ camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
426
+
427
+ # 索引处理
428
+ latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
429
+ clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
430
+ clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
431
+ clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
432
+
433
+ # 初始化要生成的latents
434
+ new_latents = torch.randn(
435
+ 1, C, current_generation, H, W,
436
+ device=device, dtype=model_dtype
437
+ )
438
+
439
+ extra_input = pipe.prepare_extra_input(new_latents)
440
+
441
+ print(f"Camera embedding shape: {camera_embedding.shape}")
442
+ print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
443
+
444
+ # 去噪循环 - 支持CFG
445
+ timesteps = pipe.scheduler.timesteps
446
+
447
+ for i, timestep in enumerate(timesteps):
448
+ if i % 10 == 0:
449
+ print(f" 去噪步骤 {i}/{len(timesteps)}")
450
+
451
+ timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
452
+
453
+ with torch.no_grad():
454
+ # 正向预测(带条件)
455
+ noise_pred_pos = pipe.dit(
456
+ new_latents,
457
+ timestep=timestep_tensor,
458
+ cam_emb=camera_embedding,
459
+ latent_indices=latent_indices,
460
+ clean_latents=clean_latents,
461
+ clean_latent_indices=clean_latent_indices,
462
+ clean_latents_2x=clean_latents_2x,
463
+ clean_latent_2x_indices=clean_latent_2x_indices,
464
+ clean_latents_4x=clean_latents_4x,
465
+ clean_latent_4x_indices=clean_latent_4x_indices,
466
+ **prompt_emb_pos,
467
+ **extra_input
468
+ )
469
+
470
+ # CFG处理
471
+ if use_camera_cfg and camera_guidance_scale > 1.0:
472
+ # 无条件预测(无camera条件)
473
+ noise_pred_uncond = pipe.dit(
474
+ new_latents,
475
+ timestep=timestep_tensor,
476
+ cam_emb=camera_embedding_uncond_batch,
477
+ latent_indices=latent_indices,
478
+ clean_latents=clean_latents,
479
+ clean_latent_indices=clean_latent_indices,
480
+ clean_latents_2x=clean_latents_2x,
481
+ clean_latent_2x_indices=clean_latent_2x_indices,
482
+ clean_latents_4x=clean_latents_4x,
483
+ clean_latent_4x_indices=clean_latent_4x_indices,
484
+ **prompt_emb_pos,
485
+ **extra_input
486
+ )
487
+
488
+ # Camera CFG
489
+ noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_pos - noise_pred_uncond)
490
+ else:
491
+ noise_pred = noise_pred_pos
492
+
493
+ # Text CFG
494
+ if prompt_emb_neg is not None and text_guidance_scale > 1.0:
495
+ noise_pred_text_uncond = pipe.dit(
496
+ new_latents,
497
+ timestep=timestep_tensor,
498
+ cam_emb=camera_embedding,
499
+ latent_indices=latent_indices,
500
+ clean_latents=clean_latents,
501
+ clean_latent_indices=clean_latent_indices,
502
+ clean_latents_2x=clean_latents_2x,
503
+ clean_latent_2x_indices=clean_latent_2x_indices,
504
+ clean_latents_4x=clean_latents_4x,
505
+ clean_latent_4x_indices=clean_latent_4x_indices,
506
+ **prompt_emb_neg,
507
+ **extra_input
508
+ )
509
+
510
+ # Text CFG
511
+ noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
512
+
513
+ new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
514
+
515
+ # 更新历史
516
+ new_latents_squeezed = new_latents.squeeze(0)
517
+ history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
518
+
519
+ # 维护滑动窗口
520
+ if history_latents.shape[1] > max_history_frames:
521
+ first_frame = history_latents[:, 0:1, :, :]
522
+ recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
523
+ history_latents = torch.cat([first_frame, recent_frames], dim=1)
524
+ print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")
525
+
526
+ print(f"更新后history_latents shape: {history_latents.shape}")
527
+
528
+ all_generated_frames.append(new_latents_squeezed)
529
+ total_generated += current_generation
530
+
531
+ print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")
532
+
533
+ # 10. 解码和保存
534
+ print("\n🔧 解码生成的视频...")
535
+
536
+ all_generated = torch.cat(all_generated_frames, dim=1)
537
+ final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
538
+
539
+ print(f"最终视频shape: {final_video.shape}")
540
+
541
+ decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
542
+
543
+ print(f"Saving video to {output_path}")
544
+
545
+ video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
546
+ video_np = (video_np * 0.5 + 0.5).clip(0, 1)
547
+ video_np = (video_np * 255).astype(np.uint8)
548
+
549
+ with imageio.get_writer(output_path, fps=20) as writer:
550
+ for frame in video_np:
551
+ writer.append_data(frame)
552
+
553
+ print(f"🔧 OpenX FramePack滑动窗口生成完成! 保存到: {output_path}")
554
+ print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
555
+
556
def main():
    """CLI entry point: parse arguments and run OpenX FramePack sliding-window inference."""
    parser = argparse.ArgumentParser(description="OpenX FramePack滑动窗口视频生成")

    # Basic I/O and generation-length parameters.
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth",
                        help="输入编码视频路径")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=24)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/openx/openx_framepack/step2000.ckpt",
                        help="训练好的模型权重路径")
    parser.add_argument("--output_path", type=str,
                        default='openx_results/output_openx_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A video of robotic manipulation task with camera movement")
    parser.add_argument("--device", type=str, default="cuda")

    # Classifier-free-guidance parameters.
    # NOTE(review): action="store_true" combined with default=True means this
    # flag can never be turned OFF from the command line — confirm intended.
    parser.add_argument("--use_camera_cfg", action="store_true", default=True,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    args = parser.parse_args()

    # Echo the effective CFG configuration before starting.
    print(f"🔧 OpenX FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"OpenX特有特性: camera间隔为4帧,适用于机器人操作任务")

    inference_openx_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale
    )
612
+
613
# Run the CLI entry point only when executed as a script (not when imported).
if __name__ == "__main__":
    main()
scripts/infer_origin.py ADDED
@@ -0,0 +1,1108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import imageio
7
+ import json
8
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
9
+ import argparse
10
+ from torchvision.transforms import v2
11
+ from einops import rearrange
12
+ import copy
13
+
14
def compute_relative_pose_matrix(pose1, pose2):
    """Compute the relative pose between two consecutive frames.

    Args:
        pose1: Camera pose of frame i, shape (7,): [tx, ty, tz, qx, qy, qz, qw].
        pose2: Camera pose of frame i+1, same layout as ``pose1``.

    Returns:
        A 3x4 numpy matrix ``[R_rel | t_rel]`` where ``R_rel`` is the relative
        rotation and ``t_rel`` the relative translation expressed in frame i's
        coordinate system.
    """
    # Bug fix: the module never imports scipy, so the bare name ``R`` used
    # below raised NameError at call time. Import it locally here.
    from scipy.spatial.transform import Rotation as R

    # Split translation and quaternion components.
    t1 = pose1[:3]  # frame-i translation [tx1, ty1, tz1]
    q1 = pose1[3:]  # frame-i quaternion [qx1, qy1, qz1, qw1]
    t2 = pose2[:3]  # frame-(i+1) translation
    q2 = pose2[3:]  # frame-(i+1) quaternion

    # 1. Relative rotation: R_rel = R2 * R1^{-1}.
    rot1 = R.from_quat(q1)
    rot2 = R.from_quat(q2)
    rot_rel = rot2 * rot1.inv()
    R_rel = rot_rel.as_matrix()

    # 2. Relative translation in frame i's coordinates: t_rel = R1^T (t2 - t1).
    R1_T = rot1.as_matrix().T  # transpose of a rotation matrix equals its inverse
    t_rel = R1_T @ (t2 - t1)

    # 3. Assemble the 3x4 matrix [R_rel | t_rel].
    relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])

    return relative_matrix
45
+
46
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a ``.pth`` file and slice a window.

    Returns:
        Tuple ``(condition_latents, encoded_data)`` where ``condition_latents``
        is the ``[C, num_frames, H, W]`` slice starting at ``start_frame`` and
        ``encoded_data`` is the full dict loaded from disk.

    Raises:
        ValueError: when the requested window runs past the encoded sequence.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    # Latents are stored channel-first over time: [C, T, H, W].
    full_latents = encoded_data['latents']

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    # Fail fast when there are not enough frames for the requested window.
    if full_latents.shape[1] < start_frame + num_frames:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
63
+
64
+
65
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A, i.e. ``pose_b @ pose_a^{-1}``.

    Both inputs must be 4x4 extrinsic matrices. With ``use_torch=True`` the
    math runs on torch tensors (numpy inputs are converted); otherwise numpy.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float torch tensors before the solve.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # numpy path: coerce anything array-like to float32 ndarrays first.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
88
+
89
+
90
def replace_dit_model_in_manager():
    """Swap the registered DiT model class for its MoE variant.

    Walks diffsynth's global ``model_loader_configs`` table and, wherever a
    config registers the ``wan_video_dit`` loader, substitutes ``WanModelMoe``
    for the original class while leaving every other entry untouched.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        # Only configs that mention the vanilla DiT loader need patching.
        if 'wan_video_dit' not in names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(names, classes):
            patched_names.append(name)
            if name == 'wan_video_dit':
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_classes.append(cls)

        # Rewrite the config tuple in place inside the global table.
        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
112
+
113
+
114
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to the DiT model.

    Idempotent: returns immediately when ``clean_x_embedder`` already exists.
    The embedder exposes three Conv3d projections ("1x"/"2x"/"4x") that map
    16-channel latents into the transformer's inner dimension.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Infer the transformer width from the first self-attention q projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Pick the projection matching the requested scale; cast the
            # input to that layer's parameter dtype before convolving.
            if scale == "1x":
                layer = self.proj
            elif scale == "2x":
                layer = self.proj_2x
            elif scale == "4x":
                layer = self.proj_4x
            else:
                raise ValueError(f"Unsupported scale: {scale}")
            return layer(x.to(layer.weight.dtype))

    # Match the embedder's dtype to the rest of the model's parameters.
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim).to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
143
+
144
+
145
def add_moe_components(dit_model, moe_config):
    """Attach MoE routing components to the DiT model in place.

    Installs one modality processor per dataset (sekai / nuscenes / openx),
    a shared global router, and a ``MultiModalMoE`` module on every
    transformer block. The config dict is stored on the model only once.
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")
    # NOTE(review): the model-level default here is 1 while each block's MoE
    # below defaults to 2 — mirrors the original code; confirm intentional.
    dit_model.top_k = moe_config.get("top_k", 1)

    # Transformer width and MoE hyper-parameters.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)

    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
    dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
    dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
    # OpenX shares sekai's 13-d camera input layout but gets its own processor.
    dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim)
    dit_model.global_router = nn.Linear(unified_dim, num_experts)

    for i, block in enumerate(dit_model.blocks):
        # Per-block MoE: consumes unified_dim features, emits the block dim.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=num_experts,
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
173
+
174
+
175
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True,direction="left"):
    """Build per-frame camera embeddings for the Sekai dataset (sliding-window version).

    Returns a bfloat16 tensor of shape [max_needed_frames, 13]: a flattened
    3x4 relative-pose matrix (12 values) per frame plus a trailing condition
    mask column (1.0 = condition frame, 0.0 = target frame).

    NOTE(review): ``total_generated`` is accepted but never used in this body.
    NOTE(review): when ``use_real_poses`` is False and ``direction`` is
    neither "left" nor "right", the function implicitly returns None.
    """
    # Latents are temporally compressed 4x relative to raw video frames.
    time_compression_ratio = 4

    # Camera frames FramePack consumes: start(1) + 4x(16) + 2x(2) + 1x(1) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # The sequence must cover both the history window and FramePack's
        # fixed index layout (with a floor of 30 frames).
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Map compressed-frame index i back to raw-video frame indices.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        # [T, 3, 4] -> [T, 12] flattened pose rows.
        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask column: frames [start_frame, start_frame + history) condition.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        if direction=="left":
            print("-----Left-------")

            max_needed_frames = max(
                start_frame + current_history_length + new_frames,
                framepack_needed_frames,
                30
            )

            print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
            relative_poses = []
            for i in range(max_needed_frames):
                # Synthetic continuous left-turn motion.
                yaw_per_frame = 0.05   # yaw per frame (positive = turn left)
                forward_speed = 0.05   # forward distance per frame

                pose = np.eye(4, dtype=np.float32)

                # Rotation about the Y axis (left turn).
                cos_yaw = np.cos(yaw_per_frame)
                sin_yaw = np.sin(yaw_per_frame)

                pose[0, 0] = cos_yaw
                pose[0, 2] = sin_yaw
                pose[2, 0] = -sin_yaw
                pose[2, 2] = cos_yaw

                # Translation: advance along the local (rotated) frame.
                pose[2, 3] = -forward_speed  # negative local Z = forward

                # Slight centripetal drift to approximate a circular path.
                radius_drift = 0.002
                pose[0, 3] = -radius_drift  # negative local X = leftwards

                relative_pose = pose[:3, :]
                relative_poses.append(torch.as_tensor(relative_pose))

            pose_embedding = torch.stack(relative_poses, dim=0)
            pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

            # Condition mask matching the generated sequence length.
            mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
            condition_end = min(start_frame + current_history_length, max_needed_frames)
            mask[start_frame:condition_end] = 1.0

            camera_embedding = torch.cat([pose_embedding, mask], dim=1)
            print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)
        elif direction=="right":
            print("------------Right----------")

            max_needed_frames = max(
                start_frame + current_history_length + new_frames,
                framepack_needed_frames,
                30
            )

            print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
            relative_poses = []
            for i in range(max_needed_frames):
                # NOTE(review): the original comment claimed "left turn", but
                # the yaw here is -0.00 — this branch actually moves straight
                # ahead (faster forward speed, no rotation).
                yaw_per_frame = -0.00
                forward_speed = 0.1  # forward distance per frame

                pose = np.eye(4, dtype=np.float32)

                # Rotation about the Y axis (zero angle -> identity rotation).
                cos_yaw = np.cos(yaw_per_frame)
                sin_yaw = np.sin(yaw_per_frame)

                pose[0, 0] = cos_yaw
                pose[0, 2] = sin_yaw
                pose[2, 0] = -sin_yaw
                pose[2, 2] = cos_yaw

                # Translation: advance along the local (rotated) frame.
                pose[2, 3] = -forward_speed  # negative local Z = forward

                # Lateral drift disabled (0.0) in this branch.
                radius_drift = 0.000
                pose[0, 3] = radius_drift

                relative_pose = pose[:3, :]
                relative_poses.append(torch.as_tensor(relative_pose))

            pose_embedding = torch.stack(relative_poses, dim=0)
            pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

            # Condition mask matching the generated sequence length.
            mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
            condition_end = min(start_frame + current_history_length, max_needed_frames)
            mask[start_frame:condition_end] = 1.0

            camera_embedding = torch.cat([pose_embedding, mask], dim=1)
            print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)
324
+
325
+
326
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Build per-frame camera embeddings for the OpenX dataset (sliding-window version).

    Returns a bfloat16 tensor of shape [max_needed_frames, 13]: a flattened
    3x4 relative-pose matrix (12 values) per frame plus a trailing condition
    mask column (1.0 = condition frame, 0.0 = target frame).
    """
    # Latents are temporally compressed 4x relative to raw video frames.
    time_compression_ratio = 4

    # Camera frames FramePack consumes: start(1) + 4x(16) + 2x(2) + 1x(1) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        # The sequence must cover both the history window and FramePack's
        # fixed index layout (with a floor of 30 frames).
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX also uses the 4x stride, but sequences are shorter.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        # [T, 3, 4] -> [T, 12] flattened pose rows.
        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask column: frames [start_frame, start_frame + history) condition.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用OpenX合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # Synthetic robot-manipulation motion: small per-frame rotations
            # and a slow translation, mimicking fine end-effector movement.
            roll_per_frame = 0.02   # slight roll
            pitch_per_frame = 0.01  # slight pitch
            yaw_per_frame = 0.015   # slight yaw
            forward_speed = 0.003   # slow forward speed

            pose = np.eye(4, dtype=np.float32)

            # Per-axis sines/cosines for the composite rotation.
            # about X (roll)
            cos_roll = np.cos(roll_per_frame)
            sin_roll = np.sin(roll_per_frame)
            # about Y (pitch)
            cos_pitch = np.cos(pitch_per_frame)
            sin_pitch = np.sin(pitch_per_frame)
            # about Z (yaw)
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            # Composite rotation matrix in ZYX order.
            pose[0, 0] = cos_yaw * cos_pitch
            pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
            pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
            pose[1, 0] = sin_yaw * cos_pitch
            pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
            pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
            pose[2, 0] = -sin_pitch
            pose[2, 1] = cos_pitch * sin_roll
            pose[2, 2] = cos_pitch * cos_roll

            # Translation: fine manipulator-style displacement, dominated by Z.
            pose[0, 3] = forward_speed * 0.5  # slight X motion
            pose[1, 3] = forward_speed * 0.3  # slight Y motion
            pose[2, 3] = -forward_speed       # main motion along depth (Z)

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask matching the generated sequence length.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
440
+
441
+
442
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Build per-frame camera embeddings for NuScenes (sliding-window version).

    Unlike the sekai/openx generators (12-d pose + mask), NuScenes uses a
    7-d pose vector (translation xyz + quaternion) plus the trailing
    condition mask, giving a bfloat16 [max_needed_frames, 8] tensor.
    Kept consistent with train_moe.py per the original note.
    """
    # Latents are temporally compressed 4x relative to raw video frames.
    # NOTE(review): unlike the other generators, this ratio is never used below.
    time_compression_ratio = 4

    # Camera frames FramePack consumes: start(1) + 4x(16) + 2x(2) + 1x(1) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if len(keyframe_poses) == 0:
            # No recorded poses at all: emit an all-zero pose sequence.
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            max_needed_frames = max(framepack_needed_frames, 30)

            pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)

            mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
            condition_end = min(start_frame + current_history_length, max_needed_frames)
            mask[start_frame:condition_end] = 1.0

            camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
            print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)

        # All poses are expressed relative to the first keyframe.
        reference_pose = keyframe_poses[0]

        max_needed_frames = max(framepack_needed_frames, 30)

        pose_vecs = []
        for i in range(max_needed_frames):
            if i < len(keyframe_poses):
                current_pose = keyframe_poses[i]

                # Translation relative to the reference keyframe.
                translation = torch.tensor(
                    np.array(current_pose['translation']) - np.array(reference_pose['translation']),
                    dtype=torch.float32
                )

                # Rotation is taken as-is (simplified: no relative rotation).
                rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)

                pose_vec = torch.cat([translation, rotation], dim=0)  # [7D]
            else:
                # Past the recorded trajectory: zero translation, identity quaternion.
                pose_vec = torch.cat([
                    torch.zeros(3, dtype=torch.float32),
                    torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
                ], dim=0)  # [7D]

            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)  # [max_needed_frames, 7]

        # Condition mask over [start_frame, start_frame + history).
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用NuScenes合成pose数据")
        max_needed_frames = max(framepack_needed_frames, 30)

        # Synthetic left-turn trajectory resembling an urban driving turn.
        pose_vecs = []
        for i in range(max_needed_frames):
            # Heading change per frame: 0.04 rad (original comment said 0.08,
            # which did not match the code) on a car-like turning radius.
            angle = i * 0.04
            radius = 15.0

            # Position on the circular arc; motion stays in the ground plane.
            x = radius * np.sin(angle)
            y = 0.0
            z = radius * (1 - np.cos(angle))

            translation = torch.tensor([x, y, z], dtype=torch.float32)

            # Vehicle heading follows the arc tangent (yaw about the Y axis).
            yaw = angle + np.pi/2
            # Quaternion layout below is [w, x, y, z] per the original code.
            rotation = torch.tensor([
                np.cos(yaw/2),  # w (real part)
                0.0,            # x
                0.0,            # y
                np.sin(yaw/2)   # z (imaginary part, about Y per original comment — verify axis)
            ], dtype=torch.float32)

            pose_vec = torch.cat([translation, rotation], dim=0)  # [7D: tx,ty,tz,qw,qx,qy,qz]
            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)

        # Condition mask over [start_frame, start_frame + history).
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes合成左转pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
548
+
549
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """Assemble FramePack conditioning inputs (MoE variant) for one generation step.

    Args:
        history_latents: [C, T, H, W] latents generated/conditioned so far.
        target_frames_to_generate: number of new latent frames to denoise.
        camera_embedding_full: [F, D] per-frame camera embedding; the last
            column is a condition mask and is rewritten here.
        start_frame: kept for interface compatibility (unused in this body).
        modality_type: dataset tag (e.g. "sekai"/"nuscenes"/"openx"), passed through.
        max_history_frames: kept for interface compatibility (window upkeep
            happens in the caller).

    Returns:
        Dict with clean latents at 1x/2x/4x scales, their index tensors, the
        sliced camera embedding with a refreshed mask, the modality tag, and
        bookkeeping lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed FramePack index layout: [start | 16 @ 4x | 2 @ 2x | 1 @ 1x | targets].
    section_sizes = [1, 16, 2, 1, target_frames_to_generate]
    total_indices_length = sum(section_sizes)
    start_idx, idx_4x, idx_2x, idx_1x, latent_indices = \
        torch.arange(0, total_indices_length).split(section_sizes, dim=0)
    clean_latent_indices = torch.cat([start_idx, idx_1x], dim=0)

    # Zero-pad the camera sequence when it is shorter than the index layout.
    if camera_embedding_full.shape[0] < total_indices_length:
        shortage = total_indices_length - camera_embedding_full.shape[0]
        padding = torch.zeros(shortage, camera_embedding_full.shape[1],
                              dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)

    # Slice out exactly the rows the index layout covers.
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()

    # Refresh the mask column: default everything to target (0.0), then flag
    # the camera rows backing the up-to-19 most recent history frames.
    combined_camera[:, -1] = 0.0
    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        combined_camera[start_pos:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Right-align the most recent history frames into a fixed 19-slot buffer.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    # Split the buffer into FramePack's three temporal scales.
    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # The very first history frame anchors the sequence; fall back to zeros
    # when there is no history yet.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
618
+
619
+
620
def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai" or "nuscenes" (an "openx" branch also exists below)
    use_real_poses=True,
    scene_info_path=None,  # only used for the NuScenes dataset
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None,
    direction="left",
    use_gt_prompt=True
):
    """MoE FramePack sliding-window video generation with multi-modality support.

    Autoregressively extends a pre-encoded condition video: each iteration
    generates ``frames_per_generation`` latent frames conditioned on a sliding
    history window (up to ``max_history_frames`` latent frames, always keeping
    the very first frame), then decodes the full sequence with the VAE and
    writes an mp4 to ``output_path``.

    All frame counts here are in *latent* time; the final log line multiplies
    by 4, so one latent frame corresponds to 4 raw video frames.

    Supports optional classifier-free guidance on the camera conditioning
    (``use_camera_cfg`` / ``camera_guidance_scale``), on the text prompt
    (``text_guidance_scale`` > 1.0), or both combined sequentially.

    NOTE(review): assumes helpers defined elsewhere in this file/module —
    ``replace_dit_model_in_manager``, ``add_framepack_components``,
    ``add_moe_components``, ``load_encoded_video_from_pth``,
    ``prepare_framepack_sliding_window_with_camera_moe`` and the per-modality
    ``generate_*_camera_embeddings_sliding`` functions.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialisation: swap in the custom DiT class, then load the
    # Wan2.1 base weights (DiT + T5 text encoder + VAE) on CPU first.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Attach the legacy per-block camera encoder (kept for checkpoint
    # compatibility). Zero-init the encoder and identity-init the projector so
    # the modules are no-ops until trained weights are loaded.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Attach FramePack components (clean-latent embedders/indices).
    add_framepack_components(pipe.dit)

    # 4. Attach MoE components; input dims are pose-dim + 1 mask channel.
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,   # Sekai: 12-dim pose + 1-dim mask
        "nuscenes_input_dim": 8, # NuScenes: 7-dim pose + 1-dim mask
        "openx_input_dim": 13    # OpenX: 12-dim pose + 1-dim mask (like sekai)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load the fine-tuned checkpoint. strict=False tolerates the newly
    # added MoE parameters that may be absent from older checkpoints.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition latents from the pre-encoded .pth file.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Spatial center-crop to the model's expected latent resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Prompt encoding, with optional CFG negative prompt.
    if use_gt_prompt and 'prompt_emb' in encoded_data:
        # Reuse the pre-encoded ground-truth prompt embedding stored alongside
        # the latents instead of re-running the text encoder.
        print("✅ 使用预编码的GT prompt embedding")
        prompt_emb_pos = encoded_data['prompt_emb']
        # Move the embedding tensors to the model's device/dtype.
        if 'context' in prompt_emb_pos:
            prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype)
        if 'context_mask' in prompt_emb_pos:
            prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype)

        # For text CFG, encode an empty prompt as the unconditional branch.
        if text_guidance_scale > 1.0:
            prompt_emb_neg = pipe.encode_prompt("")
            print(f"使用Text CFG with GT prompt,guidance scale: {text_guidance_scale}")
        else:
            prompt_emb_neg = None
            print("不使用Text CFG")

        # Log the raw GT prompt text when it was stored with the embedding.
        if 'prompt' in encoded_data['prompt_emb']:
            gt_prompt_text = encoded_data['prompt_emb']['prompt']
            print(f"📝 GT Prompt文本: {gt_prompt_text}")
    else:
        # Re-encode the prompt passed as an argument.
        print(f"🔄 重新编码prompt: {prompt}")
        if text_guidance_scale > 1.0:
            prompt_emb_pos = pipe.encode_prompt(prompt)
            prompt_emb_neg = pipe.encode_prompt("")
            print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
        else:
            prompt_emb_pos = pipe.encode_prompt(prompt)
            prompt_emb_neg = None
            print("不使用Text CFG")

    # 8. Scene metadata (NuScenes pose source) is optional.
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera-embedding sequence for the whole run.
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            0,
            max_history_frames,
            0,
            0,
            use_real_poses=use_real_poses,
            direction=direction
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            0,
            max_history_frames,
            0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            0,
            max_history_frames,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. For camera CFG, the unconditional branch uses an all-zero embedding.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        # Last chunk may be shorter than frames_per_generation.
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # Build FramePack conditioning (clean latents at 1x/2x/4x temporal
        # scales, camera embedding slice, and index tensors) for this step.
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Add a batch dimension to all conditioning tensors.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # MoE routing input keyed by modality name.
        modality_inputs = {modality_type: camera_embedding}

        # Unconditional (zero-camera) counterpart, sliced to the same length.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors are consumed on CPU.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Start denoising from pure Gaussian noise for the new frames.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop with optional CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f"  去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional pass (with camera conditioning).
                    noise_pred_cond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional pass (zeroed camera conditioning).
                    noise_pred_uncond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # MoE unconditional input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG combination.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally apply text CFG on top of the camera-CFG output.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loess = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Text CFG applied to the camera-CFG result.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only (no camera CFG).
                    noise_pred_cond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Plain inference (no CFG).
                    noise_pred, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the freshly denoised frames to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: always keep the very first frame as an
        # anchor plus the most recent (max_history_frames - 1) frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode the full latent sequence and write the video.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] float -> [0, 255] uint8, frames-last layout for imageio.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
1018
+
1019
def main():
    """CLI entry point for MoE FramePack sliding-window generation.

    Parses command-line arguments, prints the run configuration, and forwards
    everything to ``inference_moe_framepack_sliding_window``.
    """
    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    # Basic arguments
    parser.add_argument("--condition_pth", type=str,
                        #default="/share_zhuyixuan05/zhuyixuan05/sekai-game-drone/00500210001_0012150_0012450/encoded_video.pth")
                        default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
                        #default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
                        #default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=24)
    parser.add_argument("--max_history_frames", type=int, default=100)
    # BUG FIX: previously `default=False` with no action/type, so any value
    # passed on the command line (even the string "False") was truthy.
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A car is driving")
    parser.add_argument("--device", type=str, default="cuda")

    # Modality selection
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="nuscenes",
                        help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # CFG parameters
    # BUG FIX: same truthy-string problem as --use_real_poses; now a flag.
    parser.add_argument("--use_camera_cfg", action="store_true", default=False,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # MoE parameters
    parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
    parser.add_argument("--direction", type=str, default="left")
    parser.add_argument("--use_gt_prompt", action="store_true", default=False,
                        help="使用数据集中的ground truth prompt embedding")

    args = parser.parse_args()

    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"使用GT Prompt: {args.use_gt_prompt}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
    # BUG FIX: log string was missing the separator ("DiT<path>").
    print(f"DiT: {args.dit_path}")

    # Warn when NuScenes is selected without real pose metadata.
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim,
        direction=args.direction,
        use_gt_prompt=args.use_gt_prompt
    )


if __name__ == "__main__":
    main()
scripts/infer_recam.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ import torch.nn as nn
4
+ from diffsynth import ModelManager, WanVideoReCamMasterPipeline, save_video, VideoData
5
+ import torch, os, imageio, argparse
6
+ from torchvision.transforms import v2
7
+ from einops import rearrange
8
+ import pandas as pd
9
+ import torchvision
10
+ from PIL import Image
11
+ import numpy as np
12
+ import json
13
+
14
class Camera(object):
    """Camera pose holder.

    Keeps both the camera-to-world matrix (``c2w_mat``) and its inverse,
    the world-to-camera matrix (``w2c_mat``), as 4x4 numpy arrays.
    """

    def __init__(self, c2w):
        """Build from any value reshapeable to a 4x4 camera-to-world matrix."""
        cam_to_world = np.array(c2w).reshape(4, 4)
        world_to_cam = np.linalg.inv(cam_to_world)
        self.c2w_mat = cam_to_world
        self.w2c_mat = world_to_cam
+
20
class TextVideoCameraDataset(torch.utils.data.Dataset):
    """Dataset yielding (text, 81-frame video clip, target camera trajectory).

    Video paths and captions come from a CSV (columns ``file_name``/``text``);
    camera extrinsics for the target trajectory come from a fixed JSON file
    under ``./example_test_data/cameras/``. Each item's camera output is a
    ``[target_frames, 12]`` bfloat16 tensor of flattened relative 3x4 poses.
    """

    def __init__(self, base_path, metadata_path, args, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, condition_frames=40, target_frames=20):
        # Build the list of video paths and matching captions from the CSV.
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "videos", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.is_i2v = is_i2v
        self.args = args
        # Which camera trajectory of the extrinsics JSON to use (e.g. "1" -> "cam01").
        self.cam_type = self.args.cam_type

        # Frame-count configuration for the condition/target split.
        self.condition_frames = condition_frames
        self.target_frames = target_frames

        # Per-frame preprocessing: crop+resize then normalize to [-1, 1].
        self.frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Scale a PIL image up so both target dimensions are covered (aspect-preserving)."""
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Read ``num_frames`` frames starting at ``start_frame_id`` with the given stride.

        Returns a ``[C, T, H, W]`` tensor (plus the raw first frame as a numpy
        array when ``is_i2v``), or ``None`` if the video is too short.
        """
        reader = imageio.get_reader(file_path)
        # Bail out if the clip cannot supply the requested frame window.
        if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None

        frames = []
        first_frame = None
        for frame_id in range(num_frames):
            frame = reader.get_data(start_frame_id + frame_id * interval)
            frame = Image.fromarray(frame)
            frame = self.crop_and_resize(frame)
            # Keep the un-normalized first frame for image-to-video use.
            if first_frame is None:
                first_frame = np.array(frame)
            frame = frame_process(frame)
            frames.append(frame)
        reader.close()

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        if self.is_i2v:
            return frames, first_frame
        else:
            return frames

    def is_image(self, file_path):
        """Return True when the path has a common image extension."""
        file_ext_name = file_path.split(".")[-1]
        if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
            return True
        return False

    def load_video(self, file_path):
        """Load a clip starting at a random valid offset within the video."""
        start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
        frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
        return frames

    def parse_matrix(self, matrix_str):
        """Parse a matrix serialized like "[a b c] [d e f] ..." into a numpy array."""
        rows = matrix_str.strip().split('] [')
        matrix = []
        for row in rows:
            row = row.replace('[', '').replace(']', '')
            matrix.append(list(map(float, row.split())))
        return np.array(matrix)

    def get_relative_pose(self, cam_params):
        """Re-express a list of ``Camera`` poses relative to the first camera.

        Returns a ``[N, 4, 4]`` float32 array whose first entry is the
        canonical target frame and whose remaining entries are the other
        cameras mapped into that frame.
        """
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]

        # Canonical pose the first camera is mapped onto (identity here,
        # since cam_to_origin is 0).
        cam_to_origin = 0
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, -cam_to_origin],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        # Transform mapping absolute poses into the first camera's frame.
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        ret_poses = np.array(ret_poses, dtype=np.float32)
        return ret_poses

    def __getitem__(self, data_id):
        """Return dict with keys: text, video ([C,81,H,W]), path, camera ([target_frames,12])."""
        text = self.text[data_id]
        path = self.path[data_id]
        video = self.load_video(path)
        if video is None:
            raise ValueError(f"{path} is not a valid video.")
        num_frames = video.shape[1]
        assert num_frames == 81
        data = {"text": text, "video": video, "path": path}

        # Load the shared target-camera extrinsics file.
        tgt_camera_path = "./example_test_data/cameras/camera_extrinsics.json"
        with open(tgt_camera_path, 'r') as file:
            cam_data = json.load(file)

        # Sample target_frames poses evenly across the 81-frame trajectory.
        cam_idx = np.linspace(0, 80, self.target_frames, dtype=int).tolist()
        traj = [self.parse_matrix(cam_data[f"frame{idx}"][f"cam{int(self.cam_type):02d}"]) for idx in cam_idx]
        traj = np.stack(traj).transpose(0, 2, 1)
        c2ws = []
        for c2w in traj:
            # Axis permutation / sign flip / translation rescale to convert the
            # stored convention into the model's camera convention.
            # NOTE(review): the /100 suggests centimeters -> meters — confirm
            # against the extrinsics file's units.
            c2w = c2w[:, [1, 2, 0, 3]]
            c2w[:3, 1] *= -1.
            c2w[:3, 3] /= 100
            c2ws.append(c2w)
        tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
        relative_poses = []
        for i in range(len(tgt_cam_params)):
            # Pose of frame i relative to frame 0; take the 3x4 part of entry [1].
            relative_pose = self.get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]])
            relative_poses.append(torch.as_tensor(relative_pose)[:,:3,:][1])
        pose_embedding = torch.stack(relative_poses, dim=0)  # [target_frames, 3, 4]
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # [target_frames, 12]
        data['camera'] = pose_embedding.to(torch.bfloat16)
        return data

    def __len__(self):
        """Number of videos listed in the metadata CSV."""
        return len(self.path)
154
+
155
def parse_args(argv=None):
    """Parse command-line arguments for ReCamMaster inference.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]``. Accepting an explicit
            list keeps the function testable without touching process argv
            (backward-compatible generalization — existing ``parse_args()``
            calls behave exactly as before).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="ReCamMaster Inference")
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="./example_test_data",
        help="The path of the Dataset.",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/recam_future_checkpoint/step1000.ckpt",
        help="Path to save the model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./results",
        help="Path to save the results.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    # NOTE: declared as str (consumers do int(self.cam_type)), so the int
    # default 1 is only seen when the flag is omitted.
    parser.add_argument(
        "--cam_type",
        type=str,
        default=1,
    )
    parser.add_argument(
        "--cfg_scale",
        type=float,
        default=5.0,
    )
    # Condition/target frame-count split for generation.
    parser.add_argument(
        "--condition_frames",
        type=int,
        default=15,
        help="Number of condition frames",
    )
    parser.add_argument(
        "--target_frames",
        type=int,
        default=15,
        help="Number of target frames to generate",
    )
    args = parser.parse_args(argv)
    return args
206
+
207
if __name__ == '__main__':
    args = parse_args()

    # 1. Load Wan2.1 pre-trained models (DiT + T5 text encoder + VAE) on CPU.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Initialize additional modules introduced in ReCamMaster: a per-block
    # camera encoder (zero-init) and projector (identity-init) so the new
    # modules are no-ops until the checkpoint overwrites them.
    dim=pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(12, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Load the ReCamMaster checkpoint (strict: every key must match).
    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    pipe.dit.load_state_dict(state_dict, strict=True)
    pipe.to("cuda")
    pipe.to(dtype=torch.bfloat16)

    # One output directory per camera trajectory type.
    output_dir = os.path.join(args.output_dir, f"cam_type{args.cam_type}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # 4. Prepare test data (source video, target camera, target trajectory).
    dataset = TextVideoCameraDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        args,
        condition_frames=args.condition_frames,  # forwarded frame split
        target_frames=args.target_frames,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    # 5. Inference: one generated video per dataset item.
    for batch_idx, batch in enumerate(dataloader):
        target_text = batch["text"]
        source_video = batch["video"]
        target_camera = batch["camera"]

        video = pipe(
            prompt=target_text,
            negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的��景,三条腿,背景人很多,倒着走",
            source_video=source_video,
            target_camera=target_camera,
            cfg_scale=args.cfg_scale,
            num_inference_steps=50,
            seed=0,
            tiled=True,
            condition_frames=args.condition_frames,
            target_frames=args.target_frames,
        )
        save_video(video, os.path.join(output_dir, f"video{batch_idx}.mp4"), fps=30, quality=5)
scripts/infer_rlbench.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import imageio
6
+ import json
7
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
8
+ import argparse
9
+ from torchvision.transforms import v2
10
+ from einops import rearrange
11
+ import torch.nn as nn
12
+
13
+
14
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a window of pre-encoded video latents from a .pth file.

    Args:
        pth_path: path to the .pth file produced by the video encoder.
        start_frame: first latent frame to extract (index in compressed
            latent time, not raw video frames).
        num_frames: number of latent frames to extract.

    Returns:
        Tuple of (latent slice of shape [C, num_frames, H, W], the full
        dict loaded from the file).

    Raises:
        ValueError: if the file holds fewer latent frames than requested.
    """
    print(f"Loading encoded video from {pth_path}")

    # NOTE: weights_only=False runs arbitrary unpickling — only load trusted files.
    payload = torch.load(pth_path, weights_only=False, map_location="cpu")

    latents = payload['latents']  # [C, T, H, W]

    print(f"Full latents shape: {latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    # Validate the requested window before slicing.
    end_frame = start_frame + num_frames
    if end_frame > latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {end_frame}, available {latents.shape[1]}")

    window = latents[:, start_frame:end_frame, :, :]

    print(f"Extracted condition latents shape: {window.shape}")

    return window, payload
45
+
46
+
47
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A.

    Computes ``pose_b @ inv(pose_a)`` for two 4x4 extrinsic matrices.
    Inputs may be numpy arrays or torch tensors; non-matching inputs are
    converted to float32 in the selected backend.

    Args:
        pose_a: 4x4 extrinsic matrix of the reference camera A.
        pose_b: 4x4 extrinsic matrix of camera B.
        use_torch: compute with torch ops when True, numpy otherwise.

    Returns:
        4x4 relative pose matrix in the selected backend's type.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    # Torch backend: promote numpy inputs, then B @ A^-1.
    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # Numpy backend: promote non-array inputs, then B @ A^-1.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
72
+
73
+
74
def generate_camera_poses_from_data(cam_data, start_frame, condition_frames, target_frames):
    """Build a camera embedding from recorded camera extrinsics.

    Args:
        cam_data: camera extrinsics sequence, indexed by original frame index.
        start_frame: start frame (compressed latent index).
        condition_frames: number of condition frames (compressed).
        target_frames: number of target frames (compressed).

    Returns:
        bfloat16 tensor of per-frame pose entries with a trailing mask column
        (1.0 marks condition frames, 0.0 marks target frames).
    """
    # One latent frame corresponds to 4 original video frames.
    time_compression_ratio = 4
    total_frames = condition_frames + target_frames

    # Camera extrinsics sequence.
    cam_extrinsic = cam_data  # [N, 4, 4]

    # Map compressed latent indices back to original frame indices.
    start_frame_original = start_frame * time_compression_ratio
    end_frame_original = (start_frame + total_frames) * time_compression_ratio

    print(f"Using camera data from frame {start_frame_original} to {end_frame_original}")

    # Sample one pose per compressed frame, every time_compression_ratio
    # original frames.
    relative_poses = []
    for i in range(total_frames):
        frame_idx = start_frame_original + i * time_compression_ratio
        # NOTE(review): next_frame_idx is never used — frame-to-frame relative
        # poses are NOT computed here despite the list's name; confirm intent.
        next_frame_idx = frame_idx + time_compression_ratio

        cam_prev = cam_extrinsic[frame_idx]

        # NOTE(review): the old comment said "take the top 3 rows", but the
        # full entry is appended. The torch.cat(dim=1) below only works if
        # each entry is a 1-D vector, so cam_data entries are presumably
        # pre-flattened pose vectors rather than 4x4 matrices — confirm.
        relative_poses.append(torch.as_tensor(cam_prev))

        print(cam_prev)  # debug: dump each sampled pose
    # Assemble the pose embedding.
    pose_embedding = torch.stack(relative_poses, dim=0)
    # print('pose_embedding init:',pose_embedding[0])
    print('pose_embedding:',pose_embedding)
    # assert False

    # pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') # [frames, 12]

    # Mask column: 1.0 for condition frames, 0.0 for target frames.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0  # condition frames
    mask = mask.view(-1, 1)

    # Combine pose and mask into one embedding.
    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]

    print(f"Generated camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
128
+
129
+
130
def generate_camera_poses(direction="forward", target_frames=10, condition_frames=20):
    """Generate a synthetic camera-pose embedding for a fixed motion direction.

    Args:
        direction: one of "forward", "backward", "left_turn", "right_turn".
            Any other value yields identity poses (no motion).
        target_frames: number of frames to generate (compressed latent frames).
        condition_frames: number of conditioning frames (compressed latent frames).

    Returns:
        bfloat16 tensor of shape [condition_frames + target_frames, 13]:
        the flattened top 3x4 of each frame-to-frame relative pose (12 values)
        plus a trailing mask column (1.0 for condition frames, else 0.0).
    """
    total_frames = condition_frames + target_frames

    # Absolute camera poses along the requested trajectory.
    poses = []
    for i in range(total_frames):
        t = i / max(1, total_frames - 1)  # normalized progress in [0, 1]

        pose = np.eye(4, dtype=np.float32)

        if direction == "forward":
            # Forward: translate along -z.
            pose[2, 3] = -t * 0.04
        elif direction == "backward":
            # Backward: translate along +z.
            pose[2, 3] = t * 2.0
        elif direction == "left_turn":
            # Left turn: move forward and left while yawing around y.
            pose[2, 3] = -t * 0.03  # forward
            pose[0, 3] = t * 0.02   # leftward
            yaw = t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)
        elif direction == "right_turn":
            # Right turn: move forward and right while yawing the other way.
            pose[2, 3] = -t * 0.03   # forward
            pose[0, 3] = -t * 0.02   # rightward
            yaw = -t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)

        poses.append(pose)

    # Frame-to-frame relative poses; keep only the top 3 rows (3x4 = 12 values).
    relative_poses = []
    for i in range(len(poses) - 1):
        relative_pose = compute_relative_pose(poses[i], poses[i + 1])
        relative_poses.append(torch.as_tensor(relative_pose[:3, :]))

    # N absolute poses yield N-1 relative transforms; pad to total_frames.
    if len(relative_poses) < total_frames:
        if relative_poses:
            relative_poses.append(relative_poses[-1])
        else:
            # total_frames == 1: nothing to difference — use the identity
            # transform instead of indexing into an empty list.
            relative_poses.append(torch.as_tensor(np.eye(4, dtype=np.float32)[:3, :]))

    pose_embedding = torch.stack(relative_poses[:total_frames], dim=0)
    # Flatten each 3x4 matrix into a 12-vector.
    pose_embedding = pose_embedding.reshape(total_frames, -1)  # [frames, 12]

    # Mask column: 1.0 marks condition frames, 0.0 marks target frames.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0
    mask = mask.view(-1, 1)

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]

    print(f"Generated {direction} movement poses:")
    print(f"  Total frames: {total_frames}")
    print(f"  Camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
210
+
211
+
212
def inference_sekai_video_from_pth(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai.mp4",
    start_frame=0,
    condition_frames=10,  # compressed latent frames
    target_frames=2,  # compressed latent frames
    device="cuda",
    prompt="a robotic arm executing precise manipulation tasks on a clean, organized desk",
    direction="forward",
    use_real_poses=True
):
    """Run Sekai video inference from a pre-encoded latent .pth file.

    Loads the Wan2.1 pipeline plus a trained ReCamMaster DiT checkpoint,
    conditions on the stored latents, denoises target frames under a camera
    pose embedding (real or synthetic), decodes, and writes an mp4.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print(f"Setting up models for {direction} movement...")

    # 1. Load base models on CPU first, then build the pipeline on GPU.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add camera components to each DiT block before loading the checkpoint,
    # so the trained cam_encoder/projector weights have somewhere to land.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        # NOTE(review): the old comment said "13-dim embedding (12D pose +
        # 1D mask)" but the layer takes 30 inputs — confirm the camera
        # embedding actually fed in is 30-dimensional.
        block.cam_encoder = nn.Linear(30, dim)
        block.projector = nn.Linear(dim, dim)
        # Zero-init the encoder and identity-init the projector; the real
        # values are overwritten by load_state_dict below (strict=True).
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Load trained DiT weights.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    pipe.scheduler.set_timesteps(50)

    print("Loading condition video from pth...")

    # Load condition latents from the pth file.
    condition_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=condition_frames
    )

    # Add the batch dimension: [C, T, H, W] -> [1, C, T, H, W].
    condition_latents = condition_latents.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

    print("Processing poses...")

    # Build the camera pose embedding: real poses from the data when
    # available and requested, otherwise a synthetic trajectory.
    if use_real_poses and 'cam_emb' in encoded_data:
        print("Using real camera poses from data")
        camera_embedding = generate_camera_poses_from_data(
            encoded_data['cam_emb'],
            start_frame=start_frame,
            condition_frames=condition_frames,
            target_frames=target_frames
        )
    else:
        print(f"Using synthetic {direction} poses")
        camera_embedding = generate_camera_poses(
            direction=direction,
            target_frames=target_frames,
            condition_frames=condition_frames
        )

    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16)

    print(f"Camera embedding shape: {camera_embedding.shape}")

    print("Encoding prompt...")

    # Encode text prompt.
    prompt_emb = pipe.encode_prompt(prompt)

    print("Generating video...")

    # Target latent dimensions follow the condition latents.
    batch_size = 1
    channels = condition_latents.shape[1]
    latent_height = condition_latents.shape[3]
    latent_width = condition_latents.shape[4]

    # Optional center crop in latent space to bound memory usage.
    target_height, target_width = 64, 64

    if latent_height > target_height or latent_width > target_width:
        # Center crop.
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2
        condition_latents = condition_latents[:, :, :,
                                              h_start:h_start+target_height,
                                              w_start:w_start+target_width]
        latent_height = target_height
        latent_width = target_width

    # Initialize target latents with noise.
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=pipe.torch_dtype
    )

    print(f"Condition latents shape: {condition_latents.shape}")
    print(f"Target latents shape: {target_latents.shape}")
    print(f"Camera embedding shape: {camera_embedding.shape}")

    # Concatenate condition and target latents along the time axis.
    combined_latents = torch.cat([condition_latents, target_latents], dim=2)
    print(f"Combined latents shape: {combined_latents.shape}")

    # Prepare extra model inputs (e.g. positional info) from the pipeline.
    extra_input = pipe.prepare_extra_input(combined_latents)

    # Denoising loop: only the target part is updated each step; the
    # condition part stays fixed as clean context.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        # Prepare timestep as a batched tensor in the pipeline dtype.
        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

        # Predict noise over the full (condition + target) sequence.
        with torch.no_grad():
            noise_pred = pipe.dit(
                combined_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                **prompt_emb,
                **extra_input
            )

        # Step the scheduler on the target slice only.
        target_noise_pred = noise_pred[:, :, condition_frames:, :, :]
        target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents)

        # Write the updated target slice back into the combined latents.
        combined_latents[:, :, condition_frames:, :, :] = target_latents

    print("Decoding video...")

    # Decode the final latent sequence with tiled VAE decoding.
    final_video = torch.cat([condition_latents, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # Save video.
    print(f"Saving video to {output_path}")

    # Convert to HWC uint8 frames: denormalize from [-1, 1] to [0, 255].
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)  # Denormalize
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"Video generation completed! Saved to {output_path}")
380
+
381
+
382
def main():
    """CLI entry point: parse arguments and run Sekai inference from a .pth file."""
    parser = argparse.ArgumentParser(description="Sekai Video Generation Inference from PTH")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/rlbench/OpenBox_demo_49/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0,
                        help="Starting frame index (compressed latent frames)")
    parser.add_argument("--condition_frames", type=int, default=8,
                        help="Number of condition frames (compressed latent frames)")
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate (compressed latent frames)")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement (if not using real poses)")
    # Fixed: this was declared with only `default=False` (no action/type), so
    # ANY supplied value — even the string "False" — parsed as truthy.
    # store_true keeps the default of False and makes it a proper boolean flag.
    parser.add_argument("--use_real_poses", action="store_true",
                        help="Use real camera poses from data")
    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/RLBench-train/step2000_dynamic.ckpt",
                        help="Path to trained DiT checkpoint")
    parser.add_argument("--output_path", type=str, default='/home/zhuyixuan05/ReCamMaster/rlbench/infer_results/output_rl_2.mp4',
                        help="Output video path")
    parser.add_argument("--prompt", type=str,
                        default="a robotic arm executing precise manipulation tasks on a clean, organized desk",
                        help="Text prompt for generation")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    # Auto-generate an output path from the input name and settings.
    # NOTE(review): with the current non-None --output_path default this
    # branch is unreachable unless the default is changed to None.
    if args.output_path is None:
        pth_filename = os.path.basename(args.condition_pth)
        name_parts = os.path.splitext(pth_filename)
        output_dir = "rlbench/infer_results"
        os.makedirs(output_dir, exist_ok=True)

        if args.use_real_poses:
            output_filename = f"{name_parts[0]}_real_poses_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"
        else:
            output_filename = f"{name_parts[0]}_{args.direction}_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"

        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = args.output_path

    print(f"Input pth: {args.condition_pth}")
    print(f"Start frame: {args.start_frame} (compressed)")
    print(f"Condition frames: {args.condition_frames} (compressed, original: {args.condition_frames * 4})")
    print(f"Target frames: {args.target_frames} (compressed, original: {args.target_frames * 4})")
    print(f"Use real poses: {args.use_real_poses}")
    print(f"Output video will be saved to: {output_path}")

    inference_sekai_video_from_pth(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=output_path,
        start_frame=args.start_frame,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        device=args.device,
        prompt=args.prompt,
        direction=args.direction,
        use_real_poses=args.use_real_poses
    )
444
+
445
+
446
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()