diff --git a/icons/move_backward.png b/icons/move_backward.png new file mode 100644 index 0000000000000000000000000000000000000000..5fe0951edbbadced88c96ba5b17dd4ebc3feb719 Binary files /dev/null and b/icons/move_backward.png differ diff --git a/icons/move_forward.png b/icons/move_forward.png new file mode 100644 index 0000000000000000000000000000000000000000..673704982275a788481b5e5aa5f38a0fe77a31c0 Binary files /dev/null and b/icons/move_forward.png differ diff --git a/icons/move_left.png b/icons/move_left.png new file mode 100644 index 0000000000000000000000000000000000000000..f1b61098b3eb46763fd7007780bef264d407dc51 Binary files /dev/null and b/icons/move_left.png differ diff --git a/icons/move_right.png b/icons/move_right.png new file mode 100644 index 0000000000000000000000000000000000000000..0f67ad8f6924ddad57698a25941e3b094554c000 Binary files /dev/null and b/icons/move_right.png differ diff --git a/icons/not_move_backward.png b/icons/not_move_backward.png new file mode 100644 index 0000000000000000000000000000000000000000..4b68435b57fb9a084503f8f79d2a099a8ee2b51c Binary files /dev/null and b/icons/not_move_backward.png differ diff --git a/icons/not_move_forward.png b/icons/not_move_forward.png new file mode 100644 index 0000000000000000000000000000000000000000..a3b2fe459a6e4a91e4aebfd78bde62b5d60085d4 Binary files /dev/null and b/icons/not_move_forward.png differ diff --git a/icons/not_move_left.png b/icons/not_move_left.png new file mode 100644 index 0000000000000000000000000000000000000000..c8f7b6c5014e86b51b2bc17aa3f39ff0dcacb858 Binary files /dev/null and b/icons/not_move_left.png differ diff --git a/icons/not_move_right.png b/icons/not_move_right.png new file mode 100644 index 0000000000000000000000000000000000000000..6f6ea264477c90e1085c933a588ca8908a84a9f3 Binary files /dev/null and b/icons/not_move_right.png differ diff --git a/icons/not_turn_down.png b/icons/not_turn_down.png new file mode 100644 index 
0000000000000000000000000000000000000000..8aba8dd9f983222a600ee870fd30cce3b8fe5ce8 Binary files /dev/null and b/icons/not_turn_down.png differ diff --git a/icons/not_turn_left.png b/icons/not_turn_left.png new file mode 100644 index 0000000000000000000000000000000000000000..b3c68c244f10f4c660178e02ecca2d3a7df575b0 Binary files /dev/null and b/icons/not_turn_left.png differ diff --git a/icons/not_turn_right.png b/icons/not_turn_right.png new file mode 100644 index 0000000000000000000000000000000000000000..273372bcd2026bc450412218d253c201b97bb721 Binary files /dev/null and b/icons/not_turn_right.png differ diff --git a/icons/not_turn_up.png b/icons/not_turn_up.png new file mode 100644 index 0000000000000000000000000000000000000000..7827f949b194d6146074486cac75e844e823ebbc Binary files /dev/null and b/icons/not_turn_up.png differ diff --git a/icons/turn_down.png b/icons/turn_down.png new file mode 100644 index 0000000000000000000000000000000000000000..0d0eee4add55fac6c325354d930428f2d9efb6d4 Binary files /dev/null and b/icons/turn_down.png differ diff --git a/icons/turn_left.png b/icons/turn_left.png new file mode 100644 index 0000000000000000000000000000000000000000..b36a090cb45d99c97a2e4496d7ab229ba55a7e3d Binary files /dev/null and b/icons/turn_left.png differ diff --git a/icons/turn_right.png b/icons/turn_right.png new file mode 100644 index 0000000000000000000000000000000000000000..ee9eccf6bca67c61e034044d9423677cd83ae0fc Binary files /dev/null and b/icons/turn_right.png differ diff --git a/icons/turn_up.png b/icons/turn_up.png new file mode 100644 index 0000000000000000000000000000000000000000..27c5f1ab9a47c02be886364548ad28030fa2781f Binary files /dev/null and b/icons/turn_up.png differ diff --git a/models/Astra/checkpoints/Put ReCamMaster ckpt file here.txt b/models/Astra/checkpoints/Put ReCamMaster ckpt file here.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/models/Astra/checkpoints/README.md b/models/Astra/checkpoints/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e304464450f284e9f13b2dc4478e843221dbc1f5 --- /dev/null +++ b/models/Astra/checkpoints/README.md @@ -0,0 +1,5 @@ +--- +license: apache-2.0 +--- +# ReCamMaster: Camera-Controlled Generative Rendering from A Single Video +Please refer to the [Github](https://github.com/KwaiVGI/ReCamMaster) README for usage. \ No newline at end of file diff --git a/scripts/add_text_emb.py b/scripts/add_text_emb.py new file mode 100644 index 0000000000000000000000000000000000000000..99da0486caa9c4abdbf51d7ea2e86f4ce4684587 --- /dev/null +++ b/scripts/add_text_emb.py @@ -0,0 +1,161 @@ +import os +import torch +import lightning as pl +from PIL import Image +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import json +import imageio +from torchvision.transforms import v2 +from einops import rearrange +import argparse +import numpy as np +import pdb +from tqdm import tqdm + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +class VideoEncoder(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([text_encoder_path, vae_path]) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + self.frame_process = v2.Compose([ + # v2.CenterCrop(size=(900, 1600)), + # v2.Resize(size=(900, 1600), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def crop_and_resize(self, image): + width, height = image.size + # print(width,height) + width_ori, height_ori_ = 832 , 480 + image = v2.functional.resize( + image, + (round(height_ori_), round(width_ori)), + interpolation=v2.InterpolationMode.BILINEAR 
+ ) + return image + + def load_video_frames(self, video_path): + """加载完整视频""" + reader = imageio.get_reader(video_path) + frames = [] + + for frame_data in reader: + frame = Image.fromarray(frame_data) + frame = self.crop_and_resize(frame) + frame = self.frame_process(frame) + frames.append(frame) + + reader.close() + + if len(frames) == 0: + return None + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + return frames + +def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir): + """编码所有场景的视频""" + + encoder = VideoEncoder(text_encoder_path, vae_path) + encoder = encoder.cuda() + encoder.pipe.device = "cuda" + + processed_count = 0 + prompt_emb = 0 + + os.makedirs(output_dir,exist_ok=True) + + required_keys = ["latents", "cam_emb", "prompt_emb"] + + + for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))): + + scene_dir = os.path.join(scenes_path, scene_name) + save_dir = os.path.join(output_dir,scene_name.split('.')[0]) + # print('in:',scene_dir) + # print('out:',save_dir) + + + # 检查是否已编码 + encoded_path = os.path.join(save_dir, "encoded_video.pth") + # if os.path.exists(encoded_path): + print(f"Checking scene {scene_name}...") + # continue + + # 加载场景信息 + + # print(encoded_path) + data = torch.load(encoded_path,weights_only=False) + missing_keys = [key for key in required_keys if key not in data] + + if missing_keys: + print(f"警告: 文件中缺少以下必要元素: {missing_keys}") + else: + print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb") + continue + # with np.load(scene_cam_path) as data: + # cam_data = data.files + # cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data} + # with open(scene_cam_path, 'rb') as f: + # cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用 + + + + # 加载和编码视频 + # video_frames = encoder.load_video_frames(video_path) + # if video_frames is None: + # print(f"Failed to load video: {video_path}") + # continue + + # 
video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16) + # print(video_frames.shape) + # 编码视频 + with torch.no_grad(): + # latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0] + + # 编码文本 + if processed_count == 0: + print('encode prompt!!!') + prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")#A video of a scene shot using a drone's front camera + del encoder.pipe.prompter + + data["prompt_emb"] = prompt_emb + + print("已添加/更新 prompt_emb 元素") + + # 保存修改后的文件(可改为新路径避免覆盖原文件) + torch.save(data, encoded_path) + + # pdb.set_trace() + # 保存编码结果 + + + print(f"Saved encoded data: {encoded_path}") + processed_count += 1 + + # except Exception as e: + # print(f"Error encoding scene {scene_name}: {e}") + # continue + print(processed_count) + print(f"Encoding completed! Processed {processed_count} scenes.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking") + parser.add_argument("--text_encoder_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth") + parser.add_argument("--vae_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth") + + parser.add_argument("--output_dir",type=str, + default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking") + + args = parser.parse_args() + encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir) diff --git a/scripts/add_text_emb_rl.py b/scripts/add_text_emb_rl.py new file mode 100644 index 0000000000000000000000000000000000000000..29d221f0bbd5cd3054ad16e35884a36bb63bd042 --- /dev/null +++ b/scripts/add_text_emb_rl.py @@ -0,0 +1,161 @@ +import os +import torch +import lightning as pl +from PIL import Image +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import json +import imageio +from torchvision.transforms import 
v2 +from einops import rearrange +import argparse +import numpy as np +import pdb +from tqdm import tqdm + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +class VideoEncoder(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([text_encoder_path, vae_path]) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + self.frame_process = v2.Compose([ + # v2.CenterCrop(size=(900, 1600)), + # v2.Resize(size=(900, 1600), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def crop_and_resize(self, image): + width, height = image.size + # print(width,height) + width_ori, height_ori_ = 832 , 480 + image = v2.functional.resize( + image, + (round(height_ori_), round(width_ori)), + interpolation=v2.InterpolationMode.BILINEAR + ) + return image + + def load_video_frames(self, video_path): + """加载完整视频""" + reader = imageio.get_reader(video_path) + frames = [] + + for frame_data in reader: + frame = Image.fromarray(frame_data) + frame = self.crop_and_resize(frame) + frame = self.frame_process(frame) + frames.append(frame) + + reader.close() + + if len(frames) == 0: + return None + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + return frames + +def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir): + """编码所有场景的视频""" + + encoder = VideoEncoder(text_encoder_path, vae_path) + encoder = encoder.cuda() + encoder.pipe.device = "cuda" + + processed_count = 0 + prompt_emb = 0 + + os.makedirs(output_dir,exist_ok=True) + + required_keys = ["latents", "cam_emb", "prompt_emb"] + + + for i, scene_name in 
tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))): + + scene_dir = os.path.join(scenes_path, scene_name) + save_dir = os.path.join(output_dir,scene_name.split('.')[0]) + # print('in:',scene_dir) + # print('out:',save_dir) + + + # 检查是否已编码 + encoded_path = os.path.join(save_dir, "encoded_video.pth") + # if os.path.exists(encoded_path): + print(f"Checking scene {scene_name}...") + # continue + + # 加载场景信息 + + # print(encoded_path) + data = torch.load(encoded_path,weights_only=False) + missing_keys = [key for key in required_keys if key not in data] + + if missing_keys: + print(f"警告: 文件中缺少以下必要元素: {missing_keys}") + else: + print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb") + continue + # with np.load(scene_cam_path) as data: + # cam_data = data.files + # cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data} + # with open(scene_cam_path, 'rb') as f: + # cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用 + + + + # 加载和编码视频 + # video_frames = encoder.load_video_frames(video_path) + # if video_frames is None: + # print(f"Failed to load video: {video_path}") + # continue + + # video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16) + # print(video_frames.shape) + # 编码视频 + with torch.no_grad(): + # latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0] + + # 编码文本 + if processed_count == 0: + print('encode prompt!!!') + prompt_emb = encoder.pipe.encode_prompt("a robotic arm executing precise manipulation tasks on a clean, organized desk")#A video of a scene shot using a drone's front camera + “A video of a scene shot using a pedestrian's front camera while walking” + del encoder.pipe.prompter + + data["prompt_emb"] = prompt_emb + + print("已添加/更新 prompt_emb 元素") + + # 保存修改后的文件(可改为新路径避免覆盖原文件) + torch.save(data, encoded_path) + + # pdb.set_trace() + # 保存编码结果 + + + print(f"Saved encoded data: {encoded_path}") + processed_count += 1 + + # except Exception as e: + # 
print(f"Error encoding scene {scene_name}: {e}") + # continue + print(processed_count) + print(f"Encoding completed! Processed {processed_count} scenes.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/rlbench") + parser.add_argument("--text_encoder_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth") + parser.add_argument("--vae_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth") + + parser.add_argument("--output_dir",type=str, + default="/share_zhuyixuan05/zhuyixuan05/rlbench") + + args = parser.parse_args() + encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir) diff --git a/scripts/add_text_emb_spatialvid.py b/scripts/add_text_emb_spatialvid.py new file mode 100644 index 0000000000000000000000000000000000000000..960123fabec348313c49f9f78413550c0c432480 --- /dev/null +++ b/scripts/add_text_emb_spatialvid.py @@ -0,0 +1,173 @@ +import os +import torch +import lightning as pl +from PIL import Image +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import json +import imageio +from torchvision.transforms import v2 +from einops import rearrange +import argparse +import numpy as np +import pdb +from tqdm import tqdm + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +class VideoEncoder(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([text_encoder_path, vae_path]) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + self.frame_process = v2.Compose([ + # v2.CenterCrop(size=(900, 1600)), + # v2.Resize(size=(900, 1600), antialias=True), + 
v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def crop_and_resize(self, image): + width, height = image.size + # print(width,height) + width_ori, height_ori_ = 832 , 480 + image = v2.functional.resize( + image, + (round(height_ori_), round(width_ori)), + interpolation=v2.InterpolationMode.BILINEAR + ) + return image + + def load_video_frames(self, video_path): + """加载完整视频""" + reader = imageio.get_reader(video_path) + frames = [] + + for frame_data in reader: + frame = Image.fromarray(frame_data) + frame = self.crop_and_resize(frame) + frame = self.frame_process(frame) + frames.append(frame) + + reader.close() + + if len(frames) == 0: + return None + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + return frames + +def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir): + """编码所有场景的视频""" + + encoder = VideoEncoder(text_encoder_path, vae_path) + encoder = encoder.cuda() + encoder.pipe.device = "cuda" + + processed_count = 0 + prompt_emb = 0 + + os.makedirs(output_dir,exist_ok=True) + + required_keys = ["latents", "cam_emb", "prompt_emb"] + + + for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))): + + scene_dir = os.path.join(scenes_path, scene_name) + save_dir = os.path.join(output_dir,scene_name.split('.')[0]) + # print('in:',scene_dir) + # print('out:',save_dir) + + + # 检查是否已编码 + encoded_path = os.path.join(save_dir, "encoded_video.pth") + # if os.path.exists(encoded_path): + # print(f"Checking scene {scene_name}...") + # continue + + # 加载场景信息 + + # print(encoded_path) + data = torch.load(encoded_path,weights_only=False, + map_location="cpu") + missing_keys = [key for key in required_keys if key not in data] + + if missing_keys: + print(f"警告: 文件 {encoded_path} 中缺少以下必要元素: {missing_keys}") + # else: + # # print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb") + # continue + # pdb.set_trace() + if 
data['prompt_emb']['context'].requires_grad: + print(f"警告: 文件 {encoded_path} 中存在含梯度变量,已消除") + + data['prompt_emb']['context'] = data['prompt_emb']['context'].detach().clone() + + # 双重保险:显式关闭梯度 + data['prompt_emb']['context'].requires_grad_(False) + + # 验证是否成功(可选) + assert not data['prompt_emb']['context'].requires_grad, "梯度仍未消除!" + torch.save(data, encoded_path) + # with np.load(scene_cam_path) as data: + # cam_data = data.files + # cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data} + # with open(scene_cam_path, 'rb') as f: + # cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用 + + + + # 加载和编码视频 + # video_frames = encoder.load_video_frames(video_path) + # if video_frames is None: + # print(f"Failed to load video: {video_path}") + # continue + + # video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16) + # print(video_frames.shape) + # 编码视频 + '''with torch.no_grad(): + # latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0] + + # 编码文本 + if processed_count == 0: + print('encode prompt!!!') + prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")#A video of a scene shot using a drone's front camera + del encoder.pipe.prompter + + data["prompt_emb"] = prompt_emb + + print("已添加/更新 prompt_emb 元素") + + # 保存修改后的文件(可改为新路径避免覆盖原文件) + torch.save(data, encoded_path) + + # pdb.set_trace() + # 保存编码结果 + + print(f"Saved encoded data: {encoded_path}")''' + processed_count += 1 + + # except Exception as e: + # print(f"Error encoding scene {scene_name}: {e}") + # continue + print(processed_count) + print(f"Encoding completed! 
Processed {processed_count} scenes.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/spatialvid") + parser.add_argument("--text_encoder_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth") + parser.add_argument("--vae_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth") + + parser.add_argument("--output_dir",type=str, + default="/share_zhuyixuan05/zhuyixuan05/spatialvid") + + args = parser.parse_args() + encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir) diff --git a/scripts/analyze_openx.py b/scripts/analyze_openx.py new file mode 100644 index 0000000000000000000000000000000000000000..90dbc60ac966df80e42a5fd20dc19b40eb610a69 --- /dev/null +++ b/scripts/analyze_openx.py @@ -0,0 +1,243 @@ +import os +import torch +from tqdm import tqdm + +def analyze_openx_dataset_frame_counts(dataset_path): + """分析OpenX数据集中的帧数分布""" + + print(f"🔧 分析OpenX数据集: {dataset_path}") + + if not os.path.exists(dataset_path): + print(f" ⚠️ 路径不存在: {dataset_path}") + return + + episode_dirs = [] + total_episodes = 0 + valid_episodes = 0 + + # 收集所有episode目录 + for item in os.listdir(dataset_path): + episode_dir = os.path.join(dataset_path, item) + if os.path.isdir(episode_dir): + total_episodes += 1 + encoded_path = os.path.join(episode_dir, "encoded_video.pth") + if os.path.exists(encoded_path): + episode_dirs.append(episode_dir) + valid_episodes += 1 + + print(f"📊 总episode数: {total_episodes}") + print(f"📊 有效episode数: {valid_episodes}") + + if len(episode_dirs) == 0: + print("❌ 没有找到有效的episode") + return + + # 统计帧数分布 + frame_counts = [] + less_than_10 = 0 + less_than_8 = 0 + less_than_5 = 0 + error_count = 0 + + print("🔧 开始分析帧数分布...") + + for episode_dir in tqdm(episode_dirs, desc="分析episodes"): + try: + encoded_data = torch.load( + os.path.join(episode_dir, "encoded_video.pth"), + 
weights_only=False, + map_location="cpu" + ) + + latents = encoded_data['latents'] # [C, T, H, W] + frame_count = latents.shape[1] # T维度 + frame_counts.append(frame_count) + + if frame_count < 10: + less_than_10 += 1 + if frame_count < 8: + less_than_8 += 1 + if frame_count < 5: + less_than_5 += 1 + + except Exception as e: + error_count += 1 + if error_count <= 5: # 只打印前5个错误 + print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}") + + # 统计结果 + total_valid = len(frame_counts) + print(f"\n📈 帧数分布统计:") + print(f" 总有效episodes: {total_valid}") + print(f" 错误episodes: {error_count}") + print(f" 最小帧数: {min(frame_counts) if frame_counts else 0}") + print(f" 最大帧数: {max(frame_counts) if frame_counts else 0}") + print(f" 平均帧数: {sum(frame_counts) / len(frame_counts):.2f}" if frame_counts else 0) + + print(f"\n🎯 关键统计:") + print(f" 帧数 < 5: {less_than_5:6d} episodes ({less_than_5/total_valid*100:.2f}%)") + print(f" 帧数 < 8: {less_than_8:6d} episodes ({less_than_8/total_valid*100:.2f}%)") + print(f" 帧数 < 10: {less_than_10:6d} episodes ({less_than_10/total_valid*100:.2f}%)") + print(f" 帧数 >= 10: {total_valid-less_than_10:6d} episodes ({(total_valid-less_than_10)/total_valid*100:.2f}%)") + + # 详细分布 + frame_counts.sort() + print(f"\n📊 详细帧数分布:") + + # 按范围统计 + ranges = [ + (1, 4, "1-4帧"), + (5, 7, "5-7帧"), + (8, 9, "8-9帧"), + (10, 19, "10-19帧"), + (20, 49, "20-49帧"), + (50, 99, "50-99帧"), + (100, float('inf'), "100+帧") + ] + + for min_f, max_f, label in ranges: + count = sum(1 for f in frame_counts if min_f <= f <= max_f) + percentage = count / total_valid * 100 + print(f" {label:8s}: {count:6d} episodes ({percentage:5.2f}%)") + + # 建议的训练配置 + print(f"\n💡 训练配置建议:") + time_compression_ratio = 4 + min_condition_compressed = 4 // time_compression_ratio # 1帧 + target_frames_compressed = 32 // time_compression_ratio # 8帧 + min_required_compressed = min_condition_compressed + target_frames_compressed # 9帧 + + usable_episodes = sum(1 for f in frame_counts if f >= min_required_compressed) 
+ usable_percentage = usable_episodes / total_valid * 100 + + print(f" 最小条件帧数(压缩后): {min_condition_compressed}") + print(f" 目标帧数(压缩后): {target_frames_compressed}") + print(f" 最小所需帧数(压缩后): {min_required_compressed}") + print(f" 可用于训练的episodes: {usable_episodes} ({usable_percentage:.2f}%)") + + # 保存详细统计到文件 + output_file = os.path.join(dataset_path, "frame_count_analysis.txt") + with open(output_file, 'w') as f: + f.write(f"OpenX Dataset Frame Count Analysis\n") + f.write(f"Dataset Path: {dataset_path}\n") + f.write(f"Analysis Date: {__import__('datetime').datetime.now()}\n\n") + + f.write(f"Total Episodes: {total_episodes}\n") + f.write(f"Valid Episodes: {total_valid}\n") + f.write(f"Error Episodes: {error_count}\n\n") + + f.write(f"Frame Count Statistics:\n") + f.write(f" Min Frames: {min(frame_counts) if frame_counts else 0}\n") + f.write(f" Max Frames: {max(frame_counts) if frame_counts else 0}\n") + f.write(f" Avg Frames: {sum(frame_counts) / len(frame_counts):.2f}\n\n" if frame_counts else " Avg Frames: 0\n\n") + + f.write(f"Key Statistics:\n") + f.write(f" < 5 frames: {less_than_5} ({less_than_5/total_valid*100:.2f}%)\n") + f.write(f" < 8 frames: {less_than_8} ({less_than_8/total_valid*100:.2f}%)\n") + f.write(f" < 10 frames: {less_than_10} ({less_than_10/total_valid*100:.2f}%)\n") + f.write(f" >= 10 frames: {total_valid-less_than_10} ({(total_valid-less_than_10)/total_valid*100:.2f}%)\n\n") + + f.write(f"Detailed Distribution:\n") + for min_f, max_f, label in ranges: + count = sum(1 for f in frame_counts if min_f <= f <= max_f) + percentage = count / total_valid * 100 + f.write(f" {label}: {count} ({percentage:.2f}%)\n") + + f.write(f"\nTraining Configuration Recommendation:\n") + f.write(f" Usable Episodes (>= {min_required_compressed} compressed frames): {usable_episodes} ({usable_percentage:.2f}%)\n") + + # 写入所有帧数 + f.write(f"\nAll Frame Counts:\n") + for i, count in enumerate(frame_counts): + f.write(f"{count}") + if (i + 1) % 20 == 0: + f.write("\n") + 
else: + f.write(", ") + + print(f"\n💾 详细统计已保存到: {output_file}") + + return { + 'total_valid': total_valid, + 'less_than_10': less_than_10, + 'less_than_8': less_than_8, + 'less_than_5': less_than_5, + 'frame_counts': frame_counts, + 'usable_episodes': usable_episodes + } + +def quick_sample_analysis(dataset_path, sample_size=1000): + """快速采样分析,用于大数据集的初步估计""" + + print(f"🚀 快速采样分析 (样本数: {sample_size})") + + episode_dirs = [] + for item in os.listdir(dataset_path): + episode_dir = os.path.join(dataset_path, item) + if os.path.isdir(episode_dir): + encoded_path = os.path.join(episode_dir, "encoded_video.pth") + if os.path.exists(encoded_path): + episode_dirs.append(episode_dir) + + if len(episode_dirs) == 0: + print("❌ 没有找到有效的episode") + return + + # 随机采样 + import random + sample_dirs = random.sample(episode_dirs, min(sample_size, len(episode_dirs))) + + frame_counts = [] + less_than_10 = 0 + + for episode_dir in tqdm(sample_dirs, desc="采样分析"): + try: + encoded_data = torch.load( + os.path.join(episode_dir, "encoded_video.pth"), + weights_only=False, + map_location="cpu" + ) + + frame_count = encoded_data['latents'].shape[1] + frame_counts.append(frame_count) + + if frame_count < 10: + less_than_10 += 1 + + except Exception as e: + continue + + total_sample = len(frame_counts) + percentage_less_than_10 = less_than_10 / total_sample * 100 + + print(f"📊 采样结果:") + print(f" 采样数量: {total_sample}") + print(f" < 10帧: {less_than_10} ({percentage_less_than_10:.2f}%)") + print(f" >= 10帧: {total_sample - less_than_10} ({100 - percentage_less_than_10:.2f}%)") + print(f" 平均帧数: {sum(frame_counts) / len(frame_counts):.2f}") + + # 估算全数据集 + total_episodes = len(episode_dirs) + estimated_less_than_10 = int(total_episodes * percentage_less_than_10 / 100) + + print(f"\n🔮 全数据集估算:") + print(f" 总episodes: {total_episodes}") + print(f" 估算 < 10帧: {estimated_less_than_10} ({percentage_less_than_10:.2f}%)") + print(f" 估算 >= 10帧: {total_episodes - estimated_less_than_10} ({100 - 
percentage_less_than_10:.2f}%)") + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="分析OpenX数据集的帧数分布") + parser.add_argument("--dataset_path", type=str, + default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded", + help="OpenX编码数据集路径") + parser.add_argument("--quick", action="store_true", help="快速采样分析模式") + parser.add_argument("--sample_size", type=int, default=1000, help="快速模式的采样数量") + + args = parser.parse_args() + + if args.quick: + quick_sample_analysis(args.dataset_path, args.sample_size) + else: + analyze_openx_dataset_frame_counts(args.dataset_path) \ No newline at end of file diff --git a/scripts/analyze_pose.py b/scripts/analyze_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..50df339bef310547371707ec5288d75cf4e3d4e7 --- /dev/null +++ b/scripts/analyze_pose.py @@ -0,0 +1,188 @@ +import os +import json +import matplotlib.pyplot as plt +import numpy as np +from pose_classifier import PoseClassifier +import torch +from collections import defaultdict + +def analyze_turning_patterns_detailed(dataset_path, num_samples=50): + """详细分析转弯模式,基于相对于reference的pose变化""" + classifier = PoseClassifier() + samples_path = os.path.join(dataset_path, "samples") + + all_analyses = [] + sample_count = 0 + + # 用于统计每个类别的样本 + class_samples = defaultdict(list) + + print("=== 开始分析样本(基于相对于reference的变化)===") + + for item in sorted(os.listdir(samples_path)): # 排序以便有序输出 + if sample_count >= num_samples: + break + + sample_dir = os.path.join(samples_path, item) + if os.path.isdir(sample_dir): + poses_path = os.path.join(sample_dir, "poses.json") + if os.path.exists(poses_path): + try: + with open(poses_path, 'r') as f: + poses_data = json.load(f) + + target_relative_poses = poses_data['target_relative_poses'] + + if len(target_relative_poses) > 0: + # 🔧 创建相对pose向量(已经是相对于reference的) + pose_vecs = [] + for pose_data in target_relative_poses: + # 相对位移(已经是相对于reference计算的) + translation = 
torch.tensor(pose_data['relative_translation'], dtype=torch.float32) + + # 🔧 相对旋转(需要从current和reference计算) + current_rotation = torch.tensor(pose_data['current_rotation'], dtype=torch.float32) + reference_rotation = torch.tensor(pose_data['reference_rotation'], dtype=torch.float32) + + # 计算相对旋转:q_relative = q_ref^-1 * q_current + relative_rotation = calculate_relative_rotation(current_rotation, reference_rotation) + + # 组合为7D向量:[relative_translation, relative_rotation] + pose_vec = torch.cat([translation, relative_rotation], dim=0) + pose_vecs.append(pose_vec) + + if pose_vecs: + pose_sequence = torch.stack(pose_vecs, dim=0) + + # 🔧 使用新的分析方法 + analysis = classifier.analyze_pose_sequence(pose_sequence) + analysis['sample_name'] = item + all_analyses.append(analysis) + + # 🔧 详细输出每个样本的分类信息 + print(f"\n--- 样本 {sample_count + 1}: {item} ---") + print(f"总帧数: {analysis['total_frames']}") + print(f"总距离: {analysis['total_distance']:.4f}") + + # 分类分布 + class_dist = analysis['class_distribution'] + print(f"分类分布:") + for class_name, count in class_dist.items(): + percentage = count / analysis['total_frames'] * 100 + print(f" {class_name}: {count} 帧 ({percentage:.1f}%)") + + # 🔧 调试前几个pose的分类过程 + print(f"前3帧的详细分类过程:") + for i in range(min(3, len(pose_vecs))): + debug_info = classifier.debug_single_pose( + pose_vecs[i][:3], pose_vecs[i][3:7] + ) + print(f" 帧{i}: {debug_info['classification']} " + f"(yaw: {debug_info['yaw_angle_deg']:.2f}°, " + f"forward: {debug_info['forward_movement']:.3f})") + + # 运动段落 + print(f"运动段落:") + for i, segment in enumerate(analysis['motion_segments']): + print(f" 段落{i+1}: {segment['class']} (帧 {segment['start_frame']}-{segment['end_frame']}, 持续 {segment['duration']} 帧)") + + # 🔧 确定主要运动类型 + dominant_class = max(class_dist.items(), key=lambda x: x[1]) + dominant_class_name = dominant_class[0] + dominant_percentage = dominant_class[1] / analysis['total_frames'] * 100 + + print(f"主要运动类型: {dominant_class_name} ({dominant_percentage:.1f}%)") + + # 将样本添加到对应类别 
+ class_samples[dominant_class_name].append({ + 'name': item, + 'percentage': dominant_percentage, + 'analysis': analysis + }) + + sample_count += 1 + + except Exception as e: + print(f"❌ 处理样本 {item} 时出错: {e}") + + print("\n" + "="*60) + print("=== 按类别分组的样本统计(基于相对于reference的变化)===") + + # 🔧 按类别输出样本列表 + for class_name in ['forward', 'backward', 'left_turn', 'right_turn']: + samples = class_samples[class_name] + print(f"\n🔸 {class_name.upper()} 类样本 (共 {len(samples)} 个):") + + if samples: + # 按主要类别占比排序 + samples.sort(key=lambda x: x['percentage'], reverse=True) + + for i, sample_info in enumerate(samples, 1): + print(f" {i:2d}. {sample_info['name']} ({sample_info['percentage']:.1f}%)") + + # 显示详细的段落信息 + segments = sample_info['analysis']['motion_segments'] + segment_summary = [] + for seg in segments: + if seg['duration'] >= 2: # 只显示持续时间>=2帧的段落 + segment_summary.append(f"{seg['class']}({seg['duration']})") + + if segment_summary: + print(f" 段落: {' -> '.join(segment_summary)}") + else: + print(" (无样本)") + + # 🔧 统计总体模式 + print(f"\n" + "="*60) + print("=== 总体统计 ===") + + total_forward = sum(a['class_distribution']['forward'] for a in all_analyses) + total_backward = sum(a['class_distribution']['backward'] for a in all_analyses) + total_left_turn = sum(a['class_distribution']['left_turn'] for a in all_analyses) + total_right_turn = sum(a['class_distribution']['right_turn'] for a in all_analyses) + total_frames = total_forward + total_backward + total_left_turn + total_right_turn + + print(f"总样本数: {len(all_analyses)}") + print(f"总帧数: {total_frames}") + print(f"Forward: {total_forward} 帧 ({total_forward/total_frames*100:.1f}%)") + print(f"Backward: {total_backward} 帧 ({total_backward/total_frames*100:.1f}%)") + print(f"Left Turn: {total_left_turn} 帧 ({total_left_turn/total_frames*100:.1f}%)") + print(f"Right Turn: {total_right_turn} 帧 ({total_right_turn/total_frames*100:.1f}%)") + + # 🔧 样本分布统计 + print(f"\n按主要类型的样本分布:") + for class_name in ['forward', 'backward', 
'left_turn', 'right_turn']: + count = len(class_samples[class_name]) + percentage = count / len(all_analyses) * 100 if all_analyses else 0 + print(f" {class_name}: {count} 样本 ({percentage:.1f}%)") + + return all_analyses, class_samples + +def calculate_relative_rotation(current_rotation, reference_rotation): + """计算相对旋转四元数""" + q_current = torch.tensor(current_rotation, dtype=torch.float32) + q_ref = torch.tensor(reference_rotation, dtype=torch.float32) + + # 计算参考旋转的逆 (q_ref^-1) + q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]]) + + # 四元数乘法计算相对旋转: q_relative = q_ref^-1 * q_current + w1, x1, y1, z1 = q_ref_inv + w2, x2, y2, z2 = q_current + + relative_rotation = torch.tensor([ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + ]) + + return relative_rotation + +if __name__ == "__main__": + dataset_path = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_2" + + print("开始详细分析pose分类(基于相对于reference的变化)...") + all_analyses, class_samples = analyze_turning_patterns_detailed(dataset_path, num_samples=4000) + + print(f"\n🎉 分析完成! 
共处理 {len(all_analyses)} 个样本") \ No newline at end of file diff --git a/scripts/batch_drone.py b/scripts/batch_drone.py new file mode 100644 index 0000000000000000000000000000000000000000..085c773efb15408d55e90dc0622eccfbbe64a1fc --- /dev/null +++ b/scripts/batch_drone.py @@ -0,0 +1,44 @@ +import os +import random +import subprocess +import time + +src_root = "/share_zhuyixuan05/zhuyixuan05/spatialvid" +dst_root = "/share_zhuyixuan05/zhuyixuan05/New_spatialvid_drone_first" +infer_script = "/home/zhuyixuan05/ReCamMaster/infer_origin.py" # 修改为你的实际路径 + +while True: + # 随机选择一个子文件夹 + subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))] + if not subdirs: + print("没有可用的子文件夹") + break + chosen = random.choice(subdirs) + chosen_dir = os.path.join(src_root, chosen) + pth_file = os.path.join(chosen_dir, "encoded_video.pth") + if not os.path.exists(pth_file): + print(f"{pth_file} 不存在,跳过") + continue + + # 生成输出文件名 + out_file = os.path.join(dst_root, f"{chosen}.mp4") + print(f"开始生成: {pth_file} -> {out_file}") + + # 构造命令 + cmd = [ + "python", infer_script, + "--condition_pth", pth_file, + "--output_path", out_file, + "--prompt", "exploring the world", + "--modality_type", "sekai", + "--direction", "right", + "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step25000_first.ckpt", + "--use_gt_prompt" + ] + + # 仅使用第二张 GPU + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = "0" + + # 执行推理 + subprocess.run(cmd, env=env) \ No newline at end of file diff --git a/scripts/batch_infer.py b/scripts/batch_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..35857c3834565c95097feb1743b120aaa373a5d3 --- /dev/null +++ b/scripts/batch_infer.py @@ -0,0 +1,186 @@ +import os +import subprocess +import argparse +from pathlib import Path +import glob + +def find_video_files(videos_dir): + """查找视频目录下的所有视频文件""" + video_extensions = ['.mp4'] + video_files = [] + + for ext in video_extensions: + pattern = 
os.path.join(videos_dir, f"*{ext}") + video_files.extend(glob.glob(pattern)) + + return sorted(video_files) + +def run_inference(condition_video, direction, dit_path, output_dir): + """运行单个推理任务""" + # 构建输出文件名 + input_filename = os.path.basename(condition_video) + name_parts = os.path.splitext(input_filename) + output_filename = f"{name_parts[0]}_{direction}{name_parts[1]}" + output_path = os.path.join(output_dir, output_filename) + + # 构建推理命令 + cmd = [ + "python", "infer_nus.py", + "--condition_video", condition_video, + "--direction", direction, + "--dit_path", dit_path, + "--output_path", output_path, + ] + + print(f"🎬 生成 {direction} 方向视频: {input_filename} -> {output_filename}") + print(f" 命令: {' '.join(cmd)}") + + try: + # 运行推理 + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + print(f"✅ 成功生成: {output_path}") + return True + except subprocess.CalledProcessError as e: + print(f"❌ 生成失败: {e}") + print(f" 错误输出: {e.stderr}") + return False + +def batch_inference(args): + """批量推理主函数""" + videos_dir = args.videos_dir + output_dir = args.output_dir + directions = args.directions + dit_path = args.dit_path + + # 检查输入目录 + if not os.path.exists(videos_dir): + print(f"❌ 视频目录不存在: {videos_dir}") + return + + # 创建输出目录 + os.makedirs(output_dir, exist_ok=True) + print(f"📁 输出目录: {output_dir}") + + # 查找所有视频文件 + video_files = find_video_files(videos_dir) + + if not video_files: + print(f"❌ 在 {videos_dir} 中没有找到视频文件") + return + + print(f"🎥 找到 {len(video_files)} 个视频文件:") + for video in video_files: + print(f" - {os.path.basename(video)}") + + print(f"🎯 将为每个视频生成以下方向: {', '.join(directions)}") + print(f"📊 总共将生成 {len(video_files) * len(directions)} 个视频") + + # 统计信息 + total_tasks = len(video_files) * len(directions) + completed_tasks = 0 + failed_tasks = 0 + + # 批量处理 + for i, video_file in enumerate(video_files, 1): + print(f"\n{'='*60}") + print(f"处理视频 {i}/{len(video_files)}: {os.path.basename(video_file)}") + print(f"{'='*60}") + + for j, direction in 
enumerate(directions, 1): + print(f"\n--- 方向 {j}/{len(directions)}: {direction} ---") + + # 检查输出文件是否已存在 + input_filename = os.path.basename(video_file) + name_parts = os.path.splitext(input_filename) + output_filename = f"{name_parts[0]}_{direction}{name_parts[1]}" + output_path = os.path.join(output_dir, output_filename) + + if os.path.exists(output_path) and not args.overwrite: + print(f"⏭️ 文件已存在,跳过: {output_filename}") + completed_tasks += 1 + continue + + # 运行推理 + success = run_inference( + condition_video=video_file, + direction=direction, + dit_path=dit_path, + output_dir=output_dir, + ) + + if success: + completed_tasks += 1 + else: + failed_tasks += 1 + + # 显示进度 + current_progress = completed_tasks + failed_tasks + print(f"📈 进度: {current_progress}/{total_tasks} " + f"(成功: {completed_tasks}, 失败: {failed_tasks})") + + # 最终统计 + print(f"\n{'='*60}") + print(f"🎉 批量推理完成!") + print(f"📊 总任务数: {total_tasks}") + print(f"✅ 成功: {completed_tasks}") + print(f"❌ 失败: {failed_tasks}") + print(f"📁 输出目录: {output_dir}") + + if failed_tasks > 0: + print(f"⚠️ 有 {failed_tasks} 个任务失败,请检查日志") + + # 列出生成的文件 + if completed_tasks > 0: + print(f"\n📋 生成的文件:") + generated_files = glob.glob(os.path.join(output_dir, "*.mp4")) + for file_path in sorted(generated_files): + print(f" - {os.path.basename(file_path)}") + +def main(): + parser = argparse.ArgumentParser(description="批量对nus/videos目录下的所有视频生成不同方向的输出") + + parser.add_argument("--videos_dir", type=str, default="/home/zhuyixuan05/ReCamMaster/nus/videos/4032", + help="输入视频目录路径") + + parser.add_argument("--output_dir", type=str, default="nus/infer_results/batch_dynamic_4032_noise", + help="输出视频目录路径") + + parser.add_argument("--directions", nargs="+", + default=["left_turn", "right_turn"], + choices=["forward", "backward", "left_turn", "right_turn"], + help="要生成的方向列表") + + parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/nus_dynamic/step15000_dynamic.ckpt", + help="训练好的DiT模型路径") + + 
parser.add_argument("--overwrite", action="store_true", + help="是否覆盖已存在的输出文件") + + parser.add_argument("--dry_run", action="store_true", + help="只显示将要执行的任务,不实际运行") + + args = parser.parse_args() + + if args.dry_run: + print("🔍 预览模式 - 只显示任务,不执行") + videos_dir = args.videos_dir + video_files = find_video_files(videos_dir) + + print(f"📁 输入目录: {videos_dir}") + print(f"📁 输出目录: {args.output_dir}") + print(f"🎥 找到视频: {len(video_files)} 个") + print(f"🎯 生成方向: {', '.join(args.directions)}") + print(f"📊 总任务数: {len(video_files) * len(args.directions)}") + + print(f"\n将要执行的任务:") + for video in video_files: + for direction in args.directions: + input_name = os.path.basename(video) + name_parts = os.path.splitext(input_name) + output_name = f"{name_parts[0]}_{direction}{name_parts[1]}" + print(f" {input_name} -> {output_name} ({direction})") + else: + batch_inference(args) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/batch_nus.py b/scripts/batch_nus.py new file mode 100644 index 0000000000000000000000000000000000000000..6322b3670ddc1d76bbd01b1669674f4345a1fc00 --- /dev/null +++ b/scripts/batch_nus.py @@ -0,0 +1,42 @@ +import os +import random +import subprocess +import time + +src_root = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes" +dst_root = "/share_zhuyixuan05/zhuyixuan05/New_nus_right_2" +infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py" # 修改为你的实际路径 + +while True: + # 随机选择一个子文件夹 + subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))] + if not subdirs: + print("没有可用的子文件夹") + break + chosen = random.choice(subdirs) + chosen_dir = os.path.join(src_root, chosen) + pth_file = os.path.join(chosen_dir, "encoded_video-480p.pth") + if not os.path.exists(pth_file): + print(f"{pth_file} 不存在,跳过") + continue + + # 生成输出文件名 + out_file = os.path.join(dst_root, f"{chosen}.mp4") + print(f"开始生成: {pth_file} -> {out_file}") + + # 构造命令 + cmd = [ + "python", infer_script, + 
"--condition_pth", pth_file, + "--output_path", out_file, + "--prompt", "a car is driving", + "--modality_type", "nuscenes", + "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt" + ] + + # 仅使用第二张 GPU + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = "1" + + # 执行推理 + subprocess.run(cmd, env=env) \ No newline at end of file diff --git a/scripts/batch_rt.py b/scripts/batch_rt.py new file mode 100644 index 0000000000000000000000000000000000000000..ab9781a256359ede171726bddb89542ae17aa4c9 --- /dev/null +++ b/scripts/batch_rt.py @@ -0,0 +1,41 @@ +import os +import random +import subprocess +import time + +src_root = "/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded" +dst_root = "/share_zhuyixuan05/zhuyixuan05/New_RT" +infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py" # 修改为你的实际路径 + +while True: + # 随机选择一个子文件夹 + subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))] + if not subdirs: + print("没有可用的子文件夹") + break + chosen = random.choice(subdirs) + chosen_dir = os.path.join(src_root, chosen) + pth_file = os.path.join(chosen_dir, "encoded_video.pth") + if not os.path.exists(pth_file): + print(f"{pth_file} 不存在,跳过") + continue + + # 生成输出文件名 + out_file = os.path.join(dst_root, f"{chosen}.mp4") + print(f"开始生成: {pth_file} -> {out_file}") + + # 构造命令 + cmd = [ + "python", infer_script, + "--condition_pth", pth_file, + "--output_path", out_file, + "--prompt", "A robotic arm is moving the object", + "--modality_type", "openx", + ] + + # 仅使用第二张 GPU + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = "1" + + # 执行推理 + subprocess.run(cmd, env=env) \ No newline at end of file diff --git a/scripts/batch_spa.py b/scripts/batch_spa.py new file mode 100644 index 0000000000000000000000000000000000000000..b86102e39654a9a4575b22dbf1791c6304de0b4f --- /dev/null +++ b/scripts/batch_spa.py @@ -0,0 +1,43 @@ +import os +import random +import subprocess +import time + +src_root = 
"/share_zhuyixuan05/zhuyixuan05/spatialvid" +dst_root = "/share_zhuyixuan05/zhuyixuan05/New_spatialvid_right" +infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py" # 修改为你的实际路径 + +while True: + # 随机选择一个子文件夹 + subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))] + if not subdirs: + print("没有可用的子文件夹") + break + chosen = random.choice(subdirs) + chosen_dir = os.path.join(src_root, chosen) + pth_file = os.path.join(chosen_dir, "encoded_video.pth") + if not os.path.exists(pth_file): + print(f"{pth_file} 不存在,跳过") + continue + + # 生成输出文件名 + out_file = os.path.join(dst_root, f"{chosen}.mp4") + print(f"开始生成: {pth_file} -> {out_file}") + + # 构造命令 + cmd = [ + "python", infer_script, + "--condition_pth", pth_file, + "--output_path", out_file, + "--prompt", "exploring the world", + "--modality_type", "sekai", + #"--direction", "left", + "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt" + ] + + # 仅使用第二张 GPU + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = "0" + + # 执行推理 + subprocess.run(cmd, env=env) \ No newline at end of file diff --git a/scripts/batch_walk.py b/scripts/batch_walk.py new file mode 100644 index 0000000000000000000000000000000000000000..c8c7111314f9010b0f59c894d220801229c1662b --- /dev/null +++ b/scripts/batch_walk.py @@ -0,0 +1,42 @@ +import os +import random +import subprocess +import time + +src_root = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes" +dst_root = "/share_zhuyixuan05/zhuyixuan05/New_walk" +infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py" # 修改为你的实际路径 + +while True: + # 随机选择一个子文件夹 + subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))] + if not subdirs: + print("没有可用的子文件夹") + break + chosen = random.choice(subdirs) + chosen_dir = os.path.join(src_root, chosen) + pth_file = os.path.join(chosen_dir, "encoded_video-480p.pth") + if not os.path.exists(pth_file): + 
print(f"{pth_file} 不存在,跳过") + continue + + # 生成输出文件名 + out_file = os.path.join(dst_root, f"{chosen}.mp4") + print(f"开始生成: {pth_file} -> {out_file}") + + # 构造命令 + cmd = [ + "python", infer_script, + "--condition_pth", pth_file, + "--output_path", out_file, + "--prompt", "a car is driving", + "--modality_type", "nuscenes", + "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt" + ] + + # 仅使用第二张 GPU + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = "1" + + # 执行推理 + subprocess.run(cmd, env=env) \ No newline at end of file diff --git a/scripts/check.py b/scripts/check.py new file mode 100644 index 0000000000000000000000000000000000000000..4360b06e388a799451da00dcebdd6e8e4ab26191 --- /dev/null +++ b/scripts/check.py @@ -0,0 +1,263 @@ +import torch +import os +import argparse +from collections import defaultdict +import time + +def load_checkpoint(ckpt_path): + """加载检查点文件""" + if not os.path.exists(ckpt_path): + return None + + try: + state_dict = torch.load(ckpt_path, map_location='cpu') + return state_dict + except Exception as e: + print(f"❌ 加载检查点失败: {e}") + return None + +def compare_parameters(state_dict1, state_dict2, threshold=1e-8): + """比较两个状态字典的参数差异""" + if state_dict1 is None or state_dict2 is None: + return None + + updated_params = {} + unchanged_params = {} + + for name, param1 in state_dict1.items(): + if name in state_dict2: + param2 = state_dict2[name] + + # 计算参数差异 + diff = torch.abs(param1 - param2) + max_diff = torch.max(diff).item() + mean_diff = torch.mean(diff).item() + + if max_diff > threshold: + updated_params[name] = { + 'max_diff': max_diff, + 'mean_diff': mean_diff, + 'shape': param1.shape + } + else: + unchanged_params[name] = { + 'max_diff': max_diff, + 'mean_diff': mean_diff, + 'shape': param1.shape + } + + return updated_params, unchanged_params + +def categorize_parameters(param_dict): + """将参数按类型分类""" + categories = { + 'moe_related': {}, + 'camera_related': {}, + 
        'framepack_related': {},
        'attention': {},
        'other': {}
    }

    # Bucket each parameter by the first keyword group its (lower-cased)
    # name matches; anything unmatched falls into 'other'.
    for name, info in param_dict.items():
        if any(keyword in name.lower() for keyword in ['moe', 'gate', 'expert', 'processor']):
            categories['moe_related'][name] = info
        elif any(keyword in name.lower() for keyword in ['cam_encoder', 'projector', 'camera']):
            categories['camera_related'][name] = info
        elif any(keyword in name.lower() for keyword in ['clean_x_embedder', 'framepack']):
            categories['framepack_related'][name] = info
        elif any(keyword in name.lower() for keyword in ['attn', 'attention']):
            categories['attention'][name] = info
        else:
            categories['other'][name] = info

    return categories

def print_category_summary(category_name, params, color_code=''):
    """Print a summary of one category of parameter diffs.

    Args:
        category_name: Display label for the category.
        params: Dict of name -> {'max_diff', 'mean_diff', 'shape'} as
            produced by compare_parameters.
        color_code: Optional emoji/prefix printed before the label.
    """
    if not params:
        print(f"{color_code} {category_name}: 无参数")
        return

    total_params = len(params)
    max_diffs = [info['max_diff'] for info in params.values()]
    mean_diffs = [info['mean_diff'] for info in params.values()]

    print(f"{color_code} {category_name} ({total_params} 个参数):")
    print(f" 最大差异范围: {min(max_diffs):.2e} ~ {max(max_diffs):.2e}")
    print(f" 平均差异范围: {min(mean_diffs):.2e} ~ {max(mean_diffs):.2e}")

    # List the parameters with the largest change, sorted by max_diff
    # descending.  NOTE: the original comment said "top 5" but the slice
    # below has always taken up to 100 entries.
    sorted_params = sorted(params.items(), key=lambda x: x[1]['max_diff'], reverse=True)
    print(f" 变化最大的参数:")
    for i, (name, info) in enumerate(sorted_params[:100]):
        shape_str = 'x'.join(map(str, info['shape']))
        print(f" {i+1}. {name} [{shape_str}]: max_diff={info['max_diff']:.2e}")

def monitor_training(checkpoint_dir, check_interval=60):
    """Poll ``checkpoint_dir`` and report parameter updates between successive checkpoints."""
    print(f"🔍 开始监控训练进度...")
    print(f"📁 检查点目录: {checkpoint_dir}")
    print(f"⏰ 检查间隔: {check_interval}秒")
    print("=" * 80)

    previous_ckpt = None
    previous_step = -1

    # Poll forever; each iteration looks for a newer step*.ckpt file.
    while True:
        try:
            # Find the latest checkpoint.
            if not os.path.exists(checkpoint_dir):
                print(f"❌ 检查点目录不存在: {checkpoint_dir}")
                time.sleep(check_interval)
                continue

            ckpt_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('step') and f.endswith('.ckpt')]
            if not ckpt_files:
                print("⏳ 未找到检查点文件,等待中...")
                time.sleep(check_interval)
                continue

            # Sort by step number and take the newest.
            ckpt_files.sort(key=lambda x: int(x.replace('step', '').replace('.ckpt', '')))
            latest_ckpt_file = ckpt_files[-1]
            latest_ckpt_path = os.path.join(checkpoint_dir, latest_ckpt_file)

            # Extract the step number from the file name.
            current_step = int(latest_ckpt_file.replace('step', '').replace('.ckpt', ''))

            if current_step <= previous_step:
                print(f"⏳ 等待新的检查点... 
(当前: step{current_step})") + time.sleep(check_interval) + continue + + print(f"\n🔍 发现新检查点: {latest_ckpt_file}") + + # 加载当前检查点 + current_state_dict = load_checkpoint(latest_ckpt_path) + if current_state_dict is None: + print("❌ 无法加载当前检查点") + time.sleep(check_interval) + continue + + if previous_ckpt is not None: + print(f"📊 比较 step{previous_step} -> step{current_step}") + + # 比较参数 + updated_params, unchanged_params = compare_parameters( + previous_ckpt, current_state_dict, threshold=1e-8 + ) + + if updated_params is None: + print("❌ 参数比较失败") + else: + # 分类显示结果 + updated_categories = categorize_parameters(updated_params) + unchanged_categories = categorize_parameters(unchanged_params) + + print(f"\n✅ 已更新的参数 (总共 {len(updated_params)} 个):") + print_category_summary("MoE相关", updated_categories['moe_related'], '🔥') + print_category_summary("Camera相关", updated_categories['camera_related'], '📷') + print_category_summary("FramePack相关", updated_categories['framepack_related'], '🎞️') + print_category_summary("注意力相关", updated_categories['attention'], '👁️') + print_category_summary("其他", updated_categories['other'], '📦') + + print(f"\n⚠️ 未更新的参数 (总共 {len(unchanged_params)} 个):") + print_category_summary("MoE相关", unchanged_categories['moe_related'], '❄️') + print_category_summary("Camera相关", unchanged_categories['camera_related'], '❄️') + print_category_summary("FramePack相关", unchanged_categories['framepack_related'], '❄️') + print_category_summary("注意力相关", unchanged_categories['attention'], '❄️') + print_category_summary("其他", unchanged_categories['other'], '❄️') + + # 检查关键组件是否在更新 + critical_keywords = ['moe', 'cam_encoder', 'projector', 'clean_x_embedder'] + critical_updated = any( + any(keyword in name.lower() for keyword in critical_keywords) + for name in updated_params.keys() + ) + + if critical_updated: + print("\n✅ 关键组件正在更新!") + else: + print("\n❌ 警告:关键组件可能未在更新!") + + # 计算更新率 + total_params = len(updated_params) + len(unchanged_params) + update_rate = len(updated_params) 
/ total_params * 100 + print(f"\n📈 参数更新率: {update_rate:.1f}% ({len(updated_params)}/{total_params})") + + # 保存当前状态用于下次比较 + previous_ckpt = current_state_dict + previous_step = current_step + + print("=" * 80) + time.sleep(check_interval) + + except KeyboardInterrupt: + print("\n👋 监控已停止") + break + except Exception as e: + print(f"❌ 监控过程中出错: {e}") + time.sleep(check_interval) + +def compare_two_checkpoints(ckpt1_path, ckpt2_path): + """比较两个特定的检查点""" + print(f"🔍 比较两个检查点:") + print(f" 检查点1: {ckpt1_path}") + print(f" 检查点2: {ckpt2_path}") + print("=" * 80) + + # 加载检查点 + state_dict1 = load_checkpoint(ckpt1_path) + state_dict2 = load_checkpoint(ckpt2_path) + + if state_dict1 is None or state_dict2 is None: + print("❌ 无法加载检查点文件") + return + + # 比较参数 + updated_params, unchanged_params = compare_parameters(state_dict1, state_dict2) + + if updated_params is None: + print("❌ 参数比较失败") + return + + # 分类显示结果 + updated_categories = categorize_parameters(updated_params) + unchanged_categories = categorize_parameters(unchanged_params) + + print(f"\n✅ 已更新的参数 (总共 {len(updated_params)} 个):") + for category_name, params in updated_categories.items(): + print_category_summary(category_name.replace('_', ' ').title(), params, '🔥') + + print(f"\n⚠️ 未更新的参数 (总共 {len(unchanged_params)} 个):") + for category_name, params in unchanged_categories.items(): + print_category_summary(category_name.replace('_', ' ').title(), params, '❄️') + + # 计算更新率 + total_params = len(updated_params) + len(unchanged_params) + update_rate = len(updated_params) / total_params * 100 + print(f"\n📈 参数更新率: {update_rate:.1f}% ({len(updated_params)}/{total_params})") + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="检查模型参数更新情况") + parser.add_argument("--checkpoint_dir", type=str, + default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe", + help="检查点目录路径") + parser.add_argument("--compare", default=True, + help="比较两个特定检查点,而不是监控") + parser.add_argument("--ckpt1", type=str, 
default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step1500_origin_cam_4.ckpt") + parser.add_argument("--ckpt2", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step500_origin_cam_4.ckpt") + parser.add_argument("--interval", type=int, default=60, + help="监控检查间隔(秒)") + parser.add_argument("--threshold", type=float, default=1e-8, + help="参数变化阈值") + + args = parser.parse_args() + + if args.compare: + if not args.ckpt1 or not args.ckpt2: + print("❌ 比较模式需要指定 --ckpt1 和 --ckpt2") + else: + compare_two_checkpoints(args.ckpt1, args.ckpt2) + else: + monitor_training(args.checkpoint_dir, args.interval) \ No newline at end of file diff --git a/scripts/decode_openx.py b/scripts/decode_openx.py new file mode 100644 index 0000000000000000000000000000000000000000..e068d6e121fecc46253561b60962bd9c95dfc641 --- /dev/null +++ b/scripts/decode_openx.py @@ -0,0 +1,428 @@ +import os +import torch +import numpy as np +from PIL import Image +import imageio +import argparse +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +from tqdm import tqdm +import json + +class VideoDecoder: + def __init__(self, vae_path, device="cuda"): + """初始化视频解码器""" + self.device = device + + # 初始化模型管理器 + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([vae_path]) + + # 创建pipeline并只保留VAE + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe = self.pipe.to(device) + + # 🔧 关键修复:确保VAE及其所有组件都在正确设备上 + self.pipe.vae = self.pipe.vae.to(device) + if hasattr(self.pipe.vae, 'model'): + self.pipe.vae.model = self.pipe.vae.model.to(device) + + print(f"✅ VAE解码器初始化完成,设备: {device}") + + def decode_latents_to_video(self, latents, output_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + """ + 将latents解码为视频 - 修正版本,修复维度处理问题 + """ + print(f"🔧 开始解码latents...") + print(f"输入latents形状: {latents.shape}") + print(f"输入latents设备: {latents.device}") + print(f"输入latents数据类型: {latents.dtype}") + 
+ # 确保latents有batch维度 + if len(latents.shape) == 4: # [C, T, H, W] + latents = latents.unsqueeze(0) # -> [1, C, T, H, W] + + # 🔧 关键修正:确保latents在正确的设备上且数据类型匹配 + model_dtype = next(self.pipe.vae.parameters()).dtype + model_device = next(self.pipe.vae.parameters()).device + + print(f"模型设备: {model_device}") + print(f"模型数据类型: {model_dtype}") + + # 将latents移动到正确的设备和数据类型 + latents = latents.to(device=model_device, dtype=model_dtype) + + print(f"解码latents形状: {latents.shape}") + print(f"解码latents设备: {latents.device}") + print(f"解码latents数据类型: {latents.dtype}") + + # 🔧 强制设置pipeline设备,确保所有操作在同一设备上 + self.pipe.device = model_device + + # 使用VAE解码 + with torch.no_grad(): + try: + if tiled: + print("🔧 尝试tiled解码...") + decoded_video = self.pipe.decode_video( + latents, + tiled=True, + tile_size=tile_size, + tile_stride=tile_stride + ) + else: + print("🔧 使用非tiled解码...") + decoded_video = self.pipe.decode_video(latents, tiled=False) + + except Exception as e: + print(f"decode_video失败,错误: {e}") + import traceback + traceback.print_exc() + + # 🔧 fallback: 尝试直接调用VAE + try: + print("🔧 尝试直接调用VAE解码...") + decoded_video = self.pipe.vae.decode( + latents.squeeze(0), # 移除batch维度 [C, T, H, W] + device=model_device, + tiled=False + ) + # 手动调整维度: VAE输出 [T, H, W, C] -> [1, T, H, W, C] + if len(decoded_video.shape) == 4: # [T, H, W, C] + decoded_video = decoded_video.unsqueeze(0) # -> [1, T, H, W, C] + except Exception as e2: + print(f"直接VAE解码也失败: {e2}") + raise e2 + + print(f"解码后视频形状: {decoded_video.shape}") + + # 🔧 关键修正:正确处理维度顺序 + video_np = None + + if len(decoded_video.shape) == 5: + # 检查不同的可能维度顺序 + if decoded_video.shape == torch.Size([1, 3, 113, 480, 832]): + # 格式: [B, C, T, H, W] -> 需要转换为 [T, H, W, C] + print("🔧 检测到格式: [B, C, T, H, W]") + video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy() # [T, H, W, C] + elif decoded_video.shape[1] == 3: + # 如果第二个维度是3,可能是 [B, C, T, H, W] + print("🔧 检测到可能的格式: [B, C, T, H, W]") + video_np = decoded_video[0].permute(1, 2, 3, 
0).to(torch.float32).cpu().numpy() # [T, H, W, C] + elif decoded_video.shape[-1] == 3: + # 如果最后一个维度是3,可能是 [B, T, H, W, C] + print("🔧 检测到格式: [B, T, H, W, C]") + video_np = decoded_video[0].to(torch.float32).cpu().numpy() # [T, H, W, C] + else: + # 尝试找到维度为3的位置 + shape = list(decoded_video.shape) + if 3 in shape: + channel_dim = shape.index(3) + print(f"🔧 检测到通道维度在位置: {channel_dim}") + + if channel_dim == 1: # [B, C, T, H, W] + video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy() + elif channel_dim == 4: # [B, T, H, W, C] + video_np = decoded_video[0].to(torch.float32).cpu().numpy() + else: + print(f"⚠️ 未知的通道维度位置: {channel_dim}") + raise ValueError(f"Cannot handle channel dimension at position {channel_dim}") + else: + print(f"⚠️ 未找到通道维度为3的位置,形状: {decoded_video.shape}") + raise ValueError(f"Cannot find channel dimension of size 3 in shape {decoded_video.shape}") + + elif len(decoded_video.shape) == 4: + # 4维张量,检查可能的格式 + if decoded_video.shape[-1] == 3: # [T, H, W, C] + video_np = decoded_video.to(torch.float32).cpu().numpy() + elif decoded_video.shape[0] == 3: # [C, T, H, W] + video_np = decoded_video.permute(1, 2, 3, 0).to(torch.float32).cpu().numpy() + else: + print(f"⚠️ 无法处理的4D视频形状: {decoded_video.shape}") + raise ValueError(f"Cannot handle 4D video tensor shape: {decoded_video.shape}") + else: + print(f"⚠️ 意外的视频维度数: {len(decoded_video.shape)}") + raise ValueError(f"Unexpected video tensor dimensions: {decoded_video.shape}") + + if video_np is None: + raise ValueError("Failed to convert video tensor to numpy array") + + print(f"转换后视频数组形状: {video_np.shape}") + + # 🔧 验证最终形状 + if len(video_np.shape) != 4: + raise ValueError(f"Expected 4D array [T, H, W, C], got {video_np.shape}") + + if video_np.shape[-1] != 3: + print(f"⚠️ 通道数异常: 期望3,实际{video_np.shape[-1]}") + print(f"完整形状: {video_np.shape}") + # 尝试其他维度排列 + if video_np.shape[0] == 3: # [C, T, H, W] + print("🔧 尝试重新排列: [C, T, H, W] -> [T, H, W, C]") + video_np = np.transpose(video_np, (1, 2, 
3, 0)) + elif video_np.shape[1] == 3: # [T, C, H, W] + print("🔧 尝试重新排列: [T, C, H, W] -> [T, H, W, C]") + video_np = np.transpose(video_np, (0, 2, 3, 1)) + else: + raise ValueError(f"Expected 3 channels (RGB), got {video_np.shape[-1]} channels") + + # 反归一化 + video_np = (video_np * 0.5 + 0.5).clip(0, 1) # 反归一化 + video_np = (video_np * 255).astype(np.uint8) + + print(f"最终视频数组形状: {video_np.shape}") + print(f"视频数组值范围: {video_np.min()} - {video_np.max()}") + + # 保存视频 + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + try: + with imageio.get_writer(output_path, fps=10, quality=8) as writer: + for frame_idx, frame in enumerate(video_np): + # 🔧 验证每一帧的形状 + if len(frame.shape) != 3 or frame.shape[-1] != 3: + print(f"⚠️ 帧 {frame_idx} 形状异常: {frame.shape}") + continue + + writer.append_data(frame) + if frame_idx % 10 == 0: + print(f" 写入帧 {frame_idx}/{len(video_np)}") + except Exception as e: + print(f"保存视频失败: {e}") + # 🔧 尝试保存前几帧为图片进行调试 + debug_dir = os.path.join(os.path.dirname(output_path), "debug_frames") + os.makedirs(debug_dir, exist_ok=True) + + for i in range(min(5, len(video_np))): + frame = video_np[i] + debug_path = os.path.join(debug_dir, f"debug_frame_{i}.png") + try: + if len(frame.shape) == 3 and frame.shape[-1] == 3: + Image.fromarray(frame).save(debug_path) + print(f"调试: 保存帧 {i} 到 {debug_path}") + else: + print(f"调试: 帧 {i} 形状异常: {frame.shape}") + except Exception as e2: + print(f"调试: 保存帧 {i} 失败: {e2}") + raise e + + print(f"✅ 视频保存到: {output_path}") + return video_np + + def save_frames_as_images(self, video_np, output_dir, prefix="frame"): + """将视频帧保存为单独的图像文件""" + os.makedirs(output_dir, exist_ok=True) + + for i, frame in enumerate(video_np): + frame_path = os.path.join(output_dir, f"{prefix}_{i:04d}.png") + # 🔧 验证帧形状 + if len(frame.shape) == 3 and frame.shape[-1] == 3: + Image.fromarray(frame).save(frame_path) + else: + print(f"⚠️ 跳过形状异常的帧 {i}: {frame.shape}") + + print(f"✅ 保存了 {len(video_np)} 帧到: {output_dir}") + +def 
decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device="cuda"): + """解码单个episode的编码数据 - 修正版本""" + print(f"\n🔧 解码episode: {encoded_pth_path}") + + # 加载编码数据 + try: + encoded_data = torch.load(encoded_pth_path, weights_only=False, map_location="cpu") + print(f"✅ 成功加载编码数据") + except Exception as e: + print(f"❌ 加载编码数据失败: {e}") + return False + + # 检查数据结构 + print("🔍 编码数据结构:") + for key, value in encoded_data.items(): + if isinstance(value, torch.Tensor): + print(f" - {key}: {value.shape}, dtype: {value.dtype}, device: {value.device}") + elif isinstance(value, dict): + print(f" - {key}: dict with keys {list(value.keys())}") + else: + print(f" - {key}: {type(value)}") + + # 获取latents + latents = encoded_data.get('latents') + if latents is None: + print("❌ 未找到latents数据") + return False + + # 🔧 确保latents在CPU上(加载时的默认状态) + if latents.device != torch.device('cpu'): + latents = latents.cpu() + print(f"🔧 将latents移动到CPU: {latents.device}") + + episode_info = encoded_data.get('episode_info', {}) + episode_idx = episode_info.get('episode_idx', 'unknown') + total_frames = episode_info.get('total_frames', latents.shape[1] * 4) # 估算原始帧数 + + print(f"Episode信息:") + print(f" - Episode索引: {episode_idx}") + print(f" - Latents形状: {latents.shape}") + print(f" - Latents设备: {latents.device}") + print(f" - Latents数据类型: {latents.dtype}") + print(f" - 原始总帧数: {total_frames}") + print(f" - 压缩后帧数: {latents.shape[1]}") + + # 创建输出目录 + episode_name = f"episode_{episode_idx:06d}" if isinstance(episode_idx, int) else f"episode_{episode_idx}" + output_dir = os.path.join(output_base_dir, episode_name) + os.makedirs(output_dir, exist_ok=True) + + # 初始化解码器 + try: + decoder = VideoDecoder(vae_path, device) + except Exception as e: + print(f"❌ 初始化解码器失败: {e}") + return False + + # 解码为视频 + video_output_path = os.path.join(output_dir, "decoded_video.mp4") + try: + video_np = decoder.decode_latents_to_video( + latents, + video_output_path, + tiled=False, # 🔧 首先尝试非tiled解码,避免tiled的复杂性 + 
tile_size=(34, 34), + tile_stride=(18, 16) + ) + + # 保存前几帧为图像(用于快速检查) + frames_dir = os.path.join(output_dir, "frames") + sample_frames = video_np[:min(10, len(video_np))] # 只保存前10帧 + decoder.save_frames_as_images(sample_frames, frames_dir, f"frame_{episode_idx}") + + # 保存解码信息 + decode_info = { + "source_pth": encoded_pth_path, + "decoded_video_path": video_output_path, + "latents_shape": list(latents.shape), + "decoded_video_shape": list(video_np.shape), + "original_total_frames": total_frames, + "decoded_frames": len(video_np), + "compression_ratio": total_frames / len(video_np) if len(video_np) > 0 else 0, + "latents_dtype": str(latents.dtype), + "latents_device": str(latents.device), + "vae_compression_ratio": total_frames / latents.shape[1] if latents.shape[1] > 0 else 0 + } + + info_path = os.path.join(output_dir, "decode_info.json") + with open(info_path, 'w') as f: + json.dump(decode_info, f, indent=2) + + print(f"✅ Episode {episode_idx} 解码完成") + print(f" - 原始帧数: {total_frames}") + print(f" - 解码帧数: {len(video_np)}") + print(f" - 压缩比: {decode_info['compression_ratio']:.2f}") + print(f" - VAE时间压缩比: {decode_info['vae_compression_ratio']:.2f}") + return True + + except Exception as e: + print(f"❌ 解码失败: {e}") + import traceback + traceback.print_exc() + return False + +def batch_decode_episodes(encoded_base_dir, vae_path, output_base_dir, max_episodes=None, device="cuda"): + """批量解码episodes""" + print(f"🔧 批量解码Open-X episodes") + print(f"源目录: {encoded_base_dir}") + print(f"输出目录: {output_base_dir}") + + # 查找所有编码的episodes + episode_dirs = [] + if os.path.exists(encoded_base_dir): + for item in sorted(os.listdir(encoded_base_dir)): # 排序确保一致性 + episode_dir = os.path.join(encoded_base_dir, item) + if os.path.isdir(episode_dir): + encoded_path = os.path.join(episode_dir, "encoded_video.pth") + if os.path.exists(encoded_path): + episode_dirs.append(encoded_path) + + print(f"找到 {len(episode_dirs)} 个编码的episodes") + + if max_episodes and len(episode_dirs) > max_episodes: + 
episode_dirs = episode_dirs[:max_episodes] + print(f"限制处理前 {max_episodes} 个episodes") + + # 批量解码 + success_count = 0 + for i, encoded_pth_path in enumerate(tqdm(episode_dirs, desc="解码episodes")): + print(f"\n{'='*60}") + print(f"处理 {i+1}/{len(episode_dirs)}: {os.path.basename(os.path.dirname(encoded_pth_path))}") + + success = decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device) + if success: + success_count += 1 + + print(f"当前成功率: {success_count}/{i+1} ({success_count/(i+1)*100:.1f}%)") + + print(f"\n🎉 批量解码完成!") + print(f"总处理: {len(episode_dirs)} 个episodes") + print(f"成功解码: {success_count} 个episodes") + print(f"成功率: {success_count/len(episode_dirs)*100:.1f}%") + +def main(): + parser = argparse.ArgumentParser(description="解码Open-X编码的latents以验证正确性 - 修正版本") + parser.add_argument("--mode", type=str, choices=["single", "batch"], default="batch", + help="解码模式:single (单个episode) 或 batch (批量)") + parser.add_argument("--encoded_pth", type=str, + default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000000/encoded_video.pth", + help="单个编码文件路径(single模式)") + parser.add_argument("--encoded_base_dir", type=str, + default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded", + help="编码数据基础目录(batch模式)") + parser.add_argument("--vae_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + help="VAE模型路径") + parser.add_argument("--output_dir", type=str, + default="./decoded_results_fixed", + help="解码输出目录") + parser.add_argument("--max_episodes", type=int, default=5, + help="最大解码episodes数量(batch模式,用于测试)") + parser.add_argument("--device", type=str, default="cuda", + help="计算设备") + + args = parser.parse_args() + + print("🔧 Open-X Latents 解码验证工具 (修正版本 - Fixed)") + print(f"模式: {args.mode}") + print(f"VAE路径: {args.vae_path}") + print(f"输出目录: {args.output_dir}") + print(f"设备: {args.device}") + + # 🔧 检查CUDA可用性 + if args.device == "cuda" and not torch.cuda.is_available(): + print("⚠️ CUDA不可用,切换到CPU") + args.device = "cpu" + + # 
确保输出目录存在 + os.makedirs(args.output_dir, exist_ok=True) + + if args.mode == "single": + print(f"输入文件: {args.encoded_pth}") + if not os.path.exists(args.encoded_pth): + print(f"❌ 输入文件不存在: {args.encoded_pth}") + return + + success = decode_single_episode(args.encoded_pth, args.vae_path, args.output_dir, args.device) + if success: + print("✅ 单个episode解码成功") + else: + print("❌ 单个episode解码失败") + + elif args.mode == "batch": + print(f"输入目录: {args.encoded_base_dir}") + print(f"最大episodes: {args.max_episodes}") + + if not os.path.exists(args.encoded_base_dir): + print(f"❌ 输入目录不存在: {args.encoded_base_dir}") + return + + batch_decode_episodes(args.encoded_base_dir, args.vae_path, args.output_dir, args.max_episodes, args.device) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/download_recam.py b/scripts/download_recam.py new file mode 100644 index 0000000000000000000000000000000000000000..ff43e20030ba114eac1ec929c166be856407e1a9 --- /dev/null +++ b/scripts/download_recam.py @@ -0,0 +1,7 @@ +from huggingface_hub import snapshot_download + +snapshot_download( + repo_id="KwaiVGI/ReCamMaster-Wan2.1", + local_dir="models/ReCamMaster/checkpoints", + resume_download=True # 支持断点续传 +) diff --git a/scripts/download_wan2.1.py b/scripts/download_wan2.1.py new file mode 100644 index 0000000000000000000000000000000000000000..158be932267ed475d4d5978e37db6f44d14f5e28 --- /dev/null +++ b/scripts/download_wan2.1.py @@ -0,0 +1,5 @@ +from modelscope import snapshot_download + + +# Download models +snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B") \ No newline at end of file diff --git a/scripts/encode_dynamic_videos.py b/scripts/encode_dynamic_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..116084bf8951da8f3716773aa024822dc854fd88 --- /dev/null +++ b/scripts/encode_dynamic_videos.py @@ -0,0 +1,141 @@ +import os +import torch +import lightning as pl +from PIL import Image +from diffsynth 
import WanVideoReCamMasterPipeline, ModelManager +import json +import imageio +from torchvision.transforms import v2 +from einops import rearrange +import argparse +from tqdm import tqdm +class VideoEncoder(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([text_encoder_path, vae_path]) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + self.frame_process = v2.Compose([ + # v2.CenterCrop(size=(900, 1600)), + # v2.Resize(size=(900, 1600), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def crop_and_resize(self, image): + width, height = image.size + width_ori, height_ori_ = 832 , 480 + image = v2.functional.resize( + image, + (round(height_ori_), round(width_ori)), + interpolation=v2.InterpolationMode.BILINEAR + ) + return image + + def load_video_frames(self, video_path): + """加载完整视频""" + reader = imageio.get_reader(video_path) + frames = [] + + for frame_data in reader: + frame = Image.fromarray(frame_data) + frame = self.crop_and_resize(frame) + frame = self.frame_process(frame) + frames.append(frame) + + reader.close() + + if len(frames) == 0: + return None + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + return frames + +def encode_scenes(scenes_path, text_encoder_path, vae_path): + """编码所有场景的视频""" + encoder = VideoEncoder(text_encoder_path, vae_path) + encoder = encoder.cuda() + encoder.pipe.device = "cuda" + + processed_count = 0 + + for idx, scene_name in enumerate(tqdm(os.listdir(scenes_path))): + if idx < 450: + continue + scene_dir = os.path.join(scenes_path, scene_name) + if not os.path.isdir(scene_dir): + continue + + # 检查是否已编码 + encoded_path = 
os.path.join(scene_dir, "encoded_video-480p-1.pth") + if os.path.exists(encoded_path): + print(f"Scene {scene_name} already encoded, skipping...") + continue + + # 加载场景信息 + scene_info_path = os.path.join(scene_dir, "scene_info.json") + if not os.path.exists(scene_info_path): + continue + + with open(scene_info_path, 'r') as f: + scene_info = json.load(f) + + # 加载视频 + video_path = os.path.join(scene_dir, scene_info['video_path']) + if not os.path.exists(video_path): + print(f"Video not found: {video_path}") + continue + + try: + print(f"Encoding scene {scene_name}...") + + # 加载和编码视频 + video_frames = encoder.load_video_frames(video_path) + if video_frames is None: + print(f"Failed to load video: {video_path}") + continue + + video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16) + + # 编码视频 + with torch.no_grad(): + latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0] + # print(latents.shape) + # assert False + # 编码文本 + # prompt_emb = encoder.pipe.encode_prompt("A car driving scene captured by front camera") + if processed_count == 0: + print('encode prompt!!!') + prompt_emb = encoder.pipe.encode_prompt("A car driving scene captured by front camera") + del encoder.pipe.prompter + + # 保存编码结果 + encoded_data = { + "latents": latents.cpu(), + "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()}, + "image_emb": {} + } + + torch.save(encoded_data, encoded_path) + print(f"Saved encoded data: {encoded_path}") + processed_count += 1 + + except Exception as e: + print(f"Error encoding scene {scene_name}: {e}") + continue + + print(f"Encoding completed! 
Processed {processed_count} scenes.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes") + parser.add_argument("--text_encoder_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth") + parser.add_argument("--vae_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth") + + args = parser.parse_args() + encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path) diff --git a/scripts/encode_openx.py b/scripts/encode_openx.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d7fe9c4177d8cf96ced35501a84bfba3f8dc32 --- /dev/null +++ b/scripts/encode_openx.py @@ -0,0 +1,466 @@ +import os +import torch +import lightning as pl +from PIL import Image +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import json +import imageio +from torchvision.transforms import v2 +from einops import rearrange +import argparse +import numpy as np +from tqdm import tqdm + +# 🔧 关键修复:设置环境变量避免GCS连接 +os.environ["TOKENIZERS_PARALLELISM"] = "false" +os.environ["TFDS_DISABLE_GCS"] = "1" + +import tensorflow_datasets as tfds +import tensorflow as tf + +class VideoEncoder(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([text_encoder_path, vae_path]) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + self.frame_process = v2.Compose([ + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def crop_and_resize(self, image, target_width=832, target_height=480): + """调整图像尺寸""" + image = v2.functional.resize( + image, + 
(target_height, target_width), + interpolation=v2.InterpolationMode.BILINEAR + ) + return image + + def load_episode_frames(self, episode_data, max_frames=300): + """🔧 从fractal数据集加载视频帧 - 基于实际observation字段优化""" + frames = [] + + steps = episode_data['steps'] + frame_count = 0 + + print(f"开始提取帧,最多 {max_frames} 帧...") + + for step_idx, step in enumerate(steps): + if frame_count >= max_frames: + break + + try: + obs = step['observation'] + + # 🔧 基于实际的observation字段,优先使用'image' + img_data = None + image_keys_to_try = [ + 'image', # ✅ 确认存在的主要图像字段 + 'rgb', # 备用RGB图像 + 'camera_image', # 备用相机图像 + 'exterior_image_1_left', # 可能的外部摄像头 + 'wrist_image', # 可能的手腕摄像头 + ] + + for img_key in image_keys_to_try: + if img_key in obs: + try: + img_tensor = obs[img_key] + img_data = img_tensor.numpy() + if step_idx < 3: # 只为前几个步骤打印 + print(f"✅ 找到图像字段: {img_key}, 形状: {img_data.shape}") + break + except Exception as e: + if step_idx < 3: + print(f"尝试字段 {img_key} 失败: {e}") + continue + + if img_data is not None: + # 确保图像数据格式正确 + if len(img_data.shape) == 3: # [H, W, C] + if img_data.dtype == np.uint8: + frame = Image.fromarray(img_data) + else: + # 如果是归一化的浮点数,转换为uint8 + if img_data.max() <= 1.0: + img_data = (img_data * 255).astype(np.uint8) + else: + img_data = img_data.astype(np.uint8) + frame = Image.fromarray(img_data) + + # 转换为RGB如果需要 + if frame.mode != 'RGB': + frame = frame.convert('RGB') + + frame = self.crop_and_resize(frame) + frame = self.frame_process(frame) + frames.append(frame) + frame_count += 1 + + if frame_count % 50 == 0: + print(f"已处理 {frame_count} 帧") + else: + if step_idx < 5: + print(f"步骤 {step_idx}: 图像形状不正确 {img_data.shape}") + else: + # 如果找不到图像,打印可用的观测键 + if step_idx < 5: # 只为前几个步骤打印 + available_keys = list(obs.keys()) + print(f"步骤 {step_idx}: 未找到图像,可用键: {available_keys}") + + except Exception as e: + print(f"处理步骤 {step_idx} 时出错: {e}") + continue + + print(f"成功提取 {len(frames)} 帧") + + if len(frames) == 0: + return None + + frames = torch.stack(frames, dim=0) + frames 
= rearrange(frames, "T C H W -> C T H W") + return frames + + def extract_camera_poses(self, episode_data, num_frames): + """🔧 从fractal数据集提取相机位姿信息 - 基于实际observation和action字段优化""" + camera_poses = [] + + steps = episode_data['steps'] + frame_count = 0 + + print("提取相机位姿信息...") + + # 🔧 累积位姿信息 + cumulative_translation = np.array([0.0, 0.0, 0.0], dtype=np.float32) + cumulative_rotation = np.array([0.0, 0.0, 0.0], dtype=np.float32) # 欧拉角 + + for step_idx, step in enumerate(steps): + if frame_count >= num_frames: + break + + try: + obs = step['observation'] + action = step.get('action', {}) + + # 🔧 基于实际的字段提取位姿变化 + pose_data = {} + found_pose = False + + # 1. 优先使用action中的world_vector(世界坐标系中的位移) + if 'world_vector' in action: + try: + world_vector = action['world_vector'].numpy() + if len(world_vector) == 3: + # 累积世界坐标位移 + cumulative_translation += world_vector + pose_data['translation'] = cumulative_translation.copy() + found_pose = True + + if step_idx < 3: + print(f"使用action.world_vector: {world_vector}, 累积位移: {cumulative_translation}") + except Exception as e: + if step_idx < 3: + print(f"action.world_vector提取失败: {e}") + + # 2. 
使用action中的rotation_delta(旋转变化) + if 'rotation_delta' in action: + try: + rotation_delta = action['rotation_delta'].numpy() + if len(rotation_delta) == 3: + # 累积旋转变化 + cumulative_rotation += rotation_delta + + # 转换为四元数(简化版本) + euler_angles = cumulative_rotation + # 欧拉角转四元数(ZYX顺序) + roll, pitch, yaw = euler_angles[0], euler_angles[1], euler_angles[2] + + # 简化的欧拉角到四元数转换 + cy = np.cos(yaw * 0.5) + sy = np.sin(yaw * 0.5) + cp = np.cos(pitch * 0.5) + sp = np.sin(pitch * 0.5) + cr = np.cos(roll * 0.5) + sr = np.sin(roll * 0.5) + + qw = cr * cp * cy + sr * sp * sy + qx = sr * cp * cy - cr * sp * sy + qy = cr * sp * cy + sr * cp * sy + qz = cr * cp * sy - sr * sp * cy + + pose_data['rotation'] = np.array([qw, qx, qy, qz], dtype=np.float32) + found_pose = True + + if step_idx < 3: + print(f"使用action.rotation_delta: {rotation_delta}, 累积旋转: {cumulative_rotation}") + except Exception as e: + if step_idx < 3: + print(f"action.rotation_delta提取失败: {e}") + + # 确保rotation字段存在 + if 'rotation' not in pose_data: + # 使用当前累积的旋转计算四元数 + roll, pitch, yaw = cumulative_rotation[0], cumulative_rotation[1], cumulative_rotation[2] + + cy = np.cos(yaw * 0.5) + sy = np.sin(yaw * 0.5) + cp = np.cos(pitch * 0.5) + sp = np.sin(pitch * 0.5) + cr = np.cos(roll * 0.5) + sr = np.sin(roll * 0.5) + + qw = cr * cp * cy + sr * sp * sy + qx = sr * cp * cy - cr * sp * sy + qy = cr * sp * cy + sr * cp * sy + qz = cr * cp * sy - sr * sp * cy + + pose_data['rotation'] = np.array([qw, qx, qy, qz], dtype=np.float32) + + camera_poses.append(pose_data) + frame_count += 1 + + except Exception as e: + print(f"提取位姿步骤 {step_idx} 时出错: {e}") + # 添加默认位姿 + pose_data = { + 'translation': cumulative_translation.copy(), + 'rotation': np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) + } + camera_poses.append(pose_data) + frame_count += 1 + + print(f"提取了 {len(camera_poses)} 个位姿") + print(f"最终累积位移: {cumulative_translation}") + print(f"最终累积旋转: {cumulative_rotation}") + + return camera_poses + + def create_camera_matrices(self, 
camera_poses): + """将位姿转换为4x4变换矩阵""" + matrices = [] + + for pose in camera_poses: + matrix = np.eye(4, dtype=np.float32) + + # 设置平移 + matrix[:3, 3] = pose['translation'] + + # 设置旋转 - 假设是四元数 [w, x, y, z] + if len(pose['rotation']) == 4: + # 四元数转旋转矩阵 + q = pose['rotation'] + w, x, y, z = q[0], q[1], q[2], q[3] + + # 四元数到旋转矩阵的转换 + matrix[0, 0] = 1 - 2*(y*y + z*z) + matrix[0, 1] = 2*(x*y - w*z) + matrix[0, 2] = 2*(x*z + w*y) + matrix[1, 0] = 2*(x*y + w*z) + matrix[1, 1] = 1 - 2*(x*x + z*z) + matrix[1, 2] = 2*(y*z - w*x) + matrix[2, 0] = 2*(x*z - w*y) + matrix[2, 1] = 2*(y*z + w*x) + matrix[2, 2] = 1 - 2*(x*x + y*y) + elif len(pose['rotation']) == 3: + # 欧拉角转换(如果需要) + pass + + matrices.append(matrix) + + return np.array(matrices) + +def encode_fractal_dataset(dataset_path, text_encoder_path, vae_path, output_dir, max_episodes=None): + """🔧 编码fractal20220817_data数据集 - 基于实际字段结构优化""" + + encoder = VideoEncoder(text_encoder_path, vae_path) + encoder = encoder.cuda() + encoder.pipe.device = "cuda" + + os.makedirs(output_dir, exist_ok=True) + + processed_count = 0 + prompt_emb = None + + try: + # 🔧 使用你提供的成功方法加载数据集 + ds = tfds.load( + "fractal20220817_data", + split="train", + data_dir=dataset_path, + ) + + print(f"✅ 成功加载fractal20220817_data数据集") + + # 限制处理的episode数量 + if max_episodes: + ds = ds.take(max_episodes) + print(f"限制处理episodes数量: {max_episodes}") + + except Exception as e: + print(f"❌ 加载数据集失败: {e}") + return + + for episode_idx, episode in enumerate(tqdm(ds, desc="处理episodes")): + try: + episode_name = f"episode_{episode_idx:06d}" + save_episode_dir = os.path.join(output_dir, episode_name) + + # 检查是否已经处理过 + encoded_path = os.path.join(save_episode_dir, "encoded_video.pth") + if os.path.exists(encoded_path): + print(f"Episode {episode_name} 已处理,跳过...") + processed_count += 1 + continue + + os.makedirs(save_episode_dir, exist_ok=True) + + print(f"\n🔧 处理episode {episode_name}...") + + # 🔧 分析episode结构(仅对前几个episode) + if episode_idx < 2: + print("Episode结构分析:") + for key 
in episode.keys(): + print(f" - {key}: {type(episode[key])}") + + # 分析第一个step的结构 + steps = episode['steps'] + for step in steps.take(1): + print("第一个step结构:") + for key in step.keys(): + print(f" - {key}: {type(step[key])}") + + if 'observation' in step: + obs = step['observation'] + print(" observation键:") + print(f" 🔍 可用字段: {list(obs.keys())}") + + # 重点检查图像和位姿相关字段 + key_fields = ['image', 'vector_to_go', 'rotation_delta_to_go', 'base_pose_tool_reached'] + for key in key_fields: + if key in obs: + try: + value = obs[key] + if hasattr(value, 'shape'): + print(f" ✅ {key}: {type(value)}, shape: {value.shape}") + else: + print(f" ✅ {key}: {type(value)}") + except Exception as e: + print(f" ❌ {key}: 无法访问 ({e})") + + if 'action' in step: + action = step['action'] + print(" action键:") + print(f" 🔍 可用字段: {list(action.keys())}") + + # 重点检查位姿相关字段 + key_fields = ['world_vector', 'rotation_delta', 'base_displacement_vector'] + for key in key_fields: + if key in action: + try: + value = action[key] + if hasattr(value, 'shape'): + print(f" ✅ {key}: {type(value)}, shape: {value.shape}") + else: + print(f" ✅ {key}: {type(value)}") + except Exception as e: + print(f" ❌ {key}: 无法访问 ({e})") + + # 加载视频帧 + video_frames = encoder.load_episode_frames(episode) + if video_frames is None: + print(f"❌ 无法加载episode {episode_name}的视频帧") + continue + + print(f"✅ Episode {episode_name} 视频形状: {video_frames.shape}") + + # 提取相机位姿 + num_frames = video_frames.shape[1] + camera_poses = encoder.extract_camera_poses(episode, num_frames) + camera_matrices = encoder.create_camera_matrices(camera_poses) + + print(f"🔧 编码episode {episode_name}...") + + # 准备相机数据 + cam_emb = { + 'extrinsic': camera_matrices, + 'intrinsic': np.eye(3, dtype=np.float32) + } + + # 编码视频 + frames_batch = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16) + + with torch.no_grad(): + latents = encoder.pipe.encode_video(frames_batch, **encoder.tiler_kwargs)[0] + + # 编码文本prompt(第一次) + if prompt_emb is None: + print('🔧 
编码prompt...') + prompt_emb = encoder.pipe.encode_prompt( + "A video of robotic manipulation task with camera movement" + ) + # 释放prompter以节省内存 + del encoder.pipe.prompter + + # 保存编码结果 + encoded_data = { + "latents": latents.cpu(), + "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v + for k, v in prompt_emb.items()}, + "cam_emb": cam_emb, + "episode_info": { + "episode_idx": episode_idx, + "total_frames": video_frames.shape[1], + "pose_extraction_method": "observation_action_based" + } + } + + torch.save(encoded_data, encoded_path) + print(f"✅ 保存编码数据: {encoded_path}") + + processed_count += 1 + print(f"✅ 已处理 {processed_count} 个episodes") + + except Exception as e: + print(f"❌ 处理episode {episode_idx}时出错: {e}") + import traceback + traceback.print_exc() + continue + + print(f"🎉 编码完成! 总共处理了 {processed_count} 个episodes") +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Encode Open-X Fractal20220817 Dataset - Based on Real Structure") + parser.add_argument("--dataset_path", type=str, + default="/share_zhuyixuan05/public_datasets/open-x/0.1.0", + help="Path to tensorflow_datasets directory") + parser.add_argument("--text_encoder_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth") + parser.add_argument("--vae_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth") + parser.add_argument("--output_dir", type=str, + default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded") + parser.add_argument("--max_episodes", type=int, default=10000, + help="Maximum number of episodes to process (default: 10 for testing)") + + args = parser.parse_args() + + # 确保输出目录存在 + os.makedirs(args.output_dir, exist_ok=True) + + print("🚀 开始编码Open-X Fractal数据集 (基于实际字段结构)...") + print(f"📁 数据集路径: {args.dataset_path}") + print(f"💾 输出目录: {args.output_dir}") + print(f"🔢 最大处理episodes: {args.max_episodes}") + print("🔧 基于实际observation和action字段的位姿提取方法") + print("✅ 优先使用 'image' 字段获取图像数据") + + 
encode_fractal_dataset( + args.dataset_path, + args.text_encoder_path, + args.vae_path, + args.output_dir, + args.max_episodes + ) \ No newline at end of file diff --git a/scripts/encode_rlbench_video.py b/scripts/encode_rlbench_video.py new file mode 100644 index 0000000000000000000000000000000000000000..289b257e8c1d149370609cb32d1ac91345b55719 --- /dev/null +++ b/scripts/encode_rlbench_video.py @@ -0,0 +1,170 @@ +import os +import torch +import lightning as pl +from PIL import Image +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import json +import imageio +from torchvision.transforms import v2 +from einops import rearrange +import argparse +import numpy as np +import pdb +from tqdm import tqdm + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +class VideoEncoder(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([text_encoder_path, vae_path]) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + self.frame_process = v2.Compose([ + # v2.CenterCrop(size=(900, 1600)), + # v2.Resize(size=(900, 1600), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def crop_and_resize(self, image): + width, height = image.size + # print(width,height) + width_ori, height_ori_ = 512 , 512 + image = v2.functional.resize( + image, + (round(height_ori_), round(width_ori)), + interpolation=v2.InterpolationMode.BILINEAR + ) + return image + + def load_video_frames(self, video_path): + """加载完整视频""" + reader = imageio.get_reader(video_path) + frames = [] + + for frame_data in reader: + frame = Image.fromarray(frame_data) + frame = self.crop_and_resize(frame) + frame = self.frame_process(frame) + 
frames.append(frame) + + reader.close() + + if len(frames) == 0: + return None + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + return frames + +def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir): + """编码所有场景的视频""" + + encoder = VideoEncoder(text_encoder_path, vae_path) + encoder = encoder.cuda() + encoder.pipe.device = "cuda" + + processed_count = 0 + prompt_emb = 0 + + os.makedirs(output_dir,exist_ok=True) + + for i, scene_name in enumerate(os.listdir(scenes_path)): + # if i < 1700: + # continue + scene_dir = os.path.join(scenes_path, scene_name) + for j, demo_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))): + demo_dir = os.path.join(scene_dir, demo_name) + for filename in os.listdir(demo_dir): + # 检查文件是否以.mp4结尾(不区分大小写) + if filename.lower().endswith('.mp4'): + # 获取完整路径 + full_path = os.path.join(demo_dir, filename) + print(full_path) + save_dir = os.path.join(output_dir,scene_name+'_'+demo_name) + # print('in:',scene_dir) + # print('out:',save_dir) + + + + os.makedirs(save_dir,exist_ok=True) + # 检查是否已编码 + encoded_path = os.path.join(save_dir, "encoded_video.pth") + if os.path.exists(encoded_path): + print(f"Scene {scene_name} already encoded, skipping...") + continue + + # 加载场景信息 + + scene_cam_path = full_path.replace("side.mp4", "data.npy") + print(scene_cam_path) + if not os.path.exists(scene_cam_path): + continue + + # with np.load(scene_cam_path) as data: + cam_data = np.load(scene_cam_path) + cam_emb = cam_data + print(cam_data.shape) + # with open(scene_cam_path, 'rb') as f: + # cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用 + + # 加载视频 + video_path = full_path + if not os.path.exists(video_path): + print(f"Video not found: {video_path}") + continue + + # try: + print(f"Encoding scene {scene_name}...Demo {demo_name}") + + # 加载和编码视频 + video_frames = encoder.load_video_frames(video_path) + if video_frames is None: + print(f"Failed to load video: {video_path}") + 
continue + + video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16) + print('video shape:',video_frames.shape) + # 编码视频 + with torch.no_grad(): + latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0] + + # 编码文本 + # if processed_count == 0: + # print('encode prompt!!!') + # prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking") + # del encoder.pipe.prompter + # pdb.set_trace() + # 保存编码结果 + encoded_data = { + "latents": latents.cpu(), + #"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()}, + "cam_emb": cam_emb + } + # pdb.set_trace() + torch.save(encoded_data, encoded_path) + print(f"Saved encoded data: {encoded_path}") + processed_count += 1 + + # except Exception as e: + # print(f"Error encoding scene {scene_name}: {e}") + # continue + + print(f"Encoding completed! Processed {processed_count} scenes.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/RLBench") + parser.add_argument("--text_encoder_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth") + parser.add_argument("--vae_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth") + + parser.add_argument("--output_dir",type=str, + default="/share_zhuyixuan05/zhuyixuan05/rlbench") + + args = parser.parse_args() + encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir) diff --git a/scripts/encode_sekai_video.py b/scripts/encode_sekai_video.py new file mode 100644 index 0000000000000000000000000000000000000000..65f47dc9a185ff832b3cfd7da53686c81a63b302 --- /dev/null +++ b/scripts/encode_sekai_video.py @@ -0,0 +1,162 @@ +import os +import torch +import lightning as pl +from PIL import Image +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import json +import imageio 
+from torchvision.transforms import v2 +from einops import rearrange +import argparse +import numpy as np +import pdb +from tqdm import tqdm + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +class VideoEncoder(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([text_encoder_path, vae_path]) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + self.frame_process = v2.Compose([ + # v2.CenterCrop(size=(900, 1600)), + # v2.Resize(size=(900, 1600), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def crop_and_resize(self, image): + width, height = image.size + # print(width,height) + width_ori, height_ori_ = 832 , 480 + image = v2.functional.resize( + image, + (round(height_ori_), round(width_ori)), + interpolation=v2.InterpolationMode.BILINEAR + ) + return image + + def load_video_frames(self, video_path): + """加载完整视频""" + reader = imageio.get_reader(video_path) + frames = [] + + for frame_data in reader: + frame = Image.fromarray(frame_data) + frame = self.crop_and_resize(frame) + frame = self.frame_process(frame) + frames.append(frame) + + reader.close() + + if len(frames) == 0: + return None + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + return frames + +def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir): + """编码所有场景的视频""" + + encoder = VideoEncoder(text_encoder_path, vae_path) + encoder = encoder.cuda() + encoder.pipe.device = "cuda" + + processed_count = 0 + prompt_emb = 0 + + os.makedirs(output_dir,exist_ok=True) + + for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))): + # if i < 1700: + # 
continue + scene_dir = os.path.join(scenes_path, scene_name) + save_dir = os.path.join(output_dir,scene_name.split('.')[0]) + # print('in:',scene_dir) + # print('out:',save_dir) + + if not scene_dir.endswith(".mp4"):# or os.path.isdir(output_dir): + continue + + + os.makedirs(save_dir,exist_ok=True) + # 检查是否已编码 + encoded_path = os.path.join(save_dir, "encoded_video.pth") + if os.path.exists(encoded_path): + print(f"Scene {scene_name} already encoded, skipping...") + continue + + # 加载场景信息 + + scene_cam_path = scene_dir.replace(".mp4", ".npz") + if not os.path.exists(scene_cam_path): + continue + + with np.load(scene_cam_path) as data: + cam_data = data.files + cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data} + # with open(scene_cam_path, 'rb') as f: + # cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用 + + # 加载视频 + video_path = scene_dir + if not os.path.exists(video_path): + print(f"Video not found: {video_path}") + continue + + # try: + print(f"Encoding scene {scene_name}...") + + # 加载和编码视频 + video_frames = encoder.load_video_frames(video_path) + if video_frames is None: + print(f"Failed to load video: {video_path}") + continue + + video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16) + print('video shape:',video_frames.shape) + # 编码视频 + with torch.no_grad(): + latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0] + + # 编码文本 + if processed_count == 0: + print('encode prompt!!!') + prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking") + del encoder.pipe.prompter + # pdb.set_trace() + # 保存编码结果 + encoded_data = { + "latents": latents.cpu(), + #"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()}, + "cam_emb": cam_emb + } + # pdb.set_trace() + torch.save(encoded_data, encoded_path) + print(f"Saved encoded data: {encoded_path}") + processed_count += 1 + + # except Exception as e: 
class VideoEncoder(pl.LightningModule):
    """Wraps the WanVideo ReCamMaster pipeline (text encoder + VAE) for offline
    encoding of videos into latents on CPU-loaded, CUDA-moved bf16 models."""

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiling options are forwarded to pipe.encode_video to bound VAE memory use.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Map uint8 HWC frames to normalized tensors in [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize a PIL image to the fixed 832x480 target.

        NOTE(review): despite the name this only resizes; the aspect ratio is
        not preserved and nothing is cropped.
        """
        target_width, target_height = 832, 480
        return v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Load every frame of the video as one normalized (C, T, H, W) tensor.

        Returns None when the video contains no frames.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frame = self.frame_process(frame)
                frames.append(frame)
        finally:
            # Close the reader even when decoding/processing a frame raises,
            # otherwise the underlying ffmpeg handle leaks.
            reader.close()

        if not frames:
            return None

        frames = torch.stack(frames, dim=0)
        return rearrange(frames, "T C H W -> C T H W")
np.load(f) # 此时cam_data仅包含数据,无文件句柄引用 + + video_name = scene_name[:-4].split('_')[0] + start_frame = int(scene_name[:-4].split('_')[1]) + end_frame = int(scene_name[:-4].split('_')[2]) + + sampled_range = range(start_frame, end_frame , chunk_size) + sampled_frames = list(sampled_range) + + sampled_chunk_end = sampled_frames[0] + 300 + start_str = f"{sampled_frames[0]:07d}" + end_str = f"{sampled_chunk_end:07d}" + + chunk_name = f"{video_name}_{start_str}_{end_str}" + save_chunk_path = os.path.join(output_dir,chunk_name,"encoded_video.pth") + + if os.path.exists(save_chunk_path): + print(f"Video {video_name} already encoded, skipping...") + continue + + # 加载视频 + video_path = scene_dir + if not os.path.exists(video_path): + print(f"Video not found: {video_path}") + continue + + video_frames = encoder.load_video_frames(video_path) + if video_frames is None: + print(f"Failed to load video: {video_path}") + continue + + video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16) + print('video shape:',video_frames.shape) + + + + # print(sampled_frames) + + print(f"Encoding scene {scene_name}...") + for sampled_chunk_start in sampled_frames: + sampled_chunk_end = sampled_chunk_start + 300 + start_str = f"{sampled_chunk_start:07d}" + end_str = f"{sampled_chunk_end:07d}" + + # 生成保存目录名(假设video_name已定义) + chunk_name = f"{video_name}_{start_str}_{end_str}" + save_chunk_dir = os.path.join(output_dir,chunk_name) + + os.makedirs(save_chunk_dir,exist_ok=True) + print(f"Encoding chunk {chunk_name}...") + + encoded_path = os.path.join(save_chunk_dir, "encoded_video.pth") + + if os.path.exists(encoded_path): + print(f"Chunk {chunk_name} already encoded, skipping...") + continue + + + chunk_frames = video_frames[:,:, sampled_chunk_start - start_frame : sampled_chunk_end - start_frame,...] 
+ # print('extrinsic:',cam_emb['extrinsic'].shape) + chunk_cam_emb ={'extrinsic':cam_emb['extrinsic'][sampled_chunk_start - start_frame : sampled_chunk_end - start_frame], + 'intrinsic':cam_emb['intrinsic']} + + # print('chunk shape:',chunk_frames.shape) + + with torch.no_grad(): + latents = encoder.pipe.encode_video(chunk_frames, **encoder.tiler_kwargs)[0] + + # 编码文本 + # if processed_count == 0: + # print('encode prompt!!!') + # prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking") + # del encoder.pipe.prompter + # pdb.set_trace() + # 保存编码结果 + encoded_data = { + "latents": latents.cpu(), + # "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()}, + "cam_emb": chunk_cam_emb + } + # pdb.set_trace() + torch.save(encoded_data, encoded_path) + print(f"Saved encoded data: {encoded_path}") + processed_chunk_count += 1 + + processed_count += 1 + + print("Encoded scene numebr:",processed_count) + print("Encoded chunk numebr:",processed_chunk_count) + + # os.makedirs(save_dir,exist_ok=True) + # # 检查是否已编码 + # encoded_path = os.path.join(save_dir, "encoded_video.pth") + # if os.path.exists(encoded_path): + # print(f"Scene {scene_name} already encoded, skipping...") + # continue + + # 加载场景信息 + + + + # try: + # print(f"Encoding scene {scene_name}...") + + # 加载和编码视频 + + # 编码视频 + # with torch.no_grad(): + # latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0] + + # # 编码文本 + # if processed_count == 0: + # print('encode prompt!!!') + # prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking") + # del encoder.pipe.prompter + # # pdb.set_trace() + # # 保存编码结果 + # encoded_data = { + # "latents": latents.cpu(), + # #"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()}, + # "cam_emb": cam_emb + # } + # # pdb.set_trace() + # torch.save(encoded_data, encoded_path) + # 
def interpolate_camera_poses(original_frames, original_poses, target_frames):
    """Interpolate camera poses at arbitrary frame indices.

    Args:
        original_frames: sorted frame indices the poses were sampled at,
            e.g. [0, 6, 12, ...].
        original_poses: (n, 7) array of [tx, ty, tz, qx, qy, qz, qw] per frame.
        target_frames: frame indices to interpolate poses for.

    Returns:
        (m, 7) numpy array with one interpolated pose per target frame.
        Translations are linearly interpolated and rotations use SLERP;
        targets outside [original_frames[0], original_frames[-1]] are clamped
        to the first/last pose.

    Raises:
        ValueError: if frame/pose counts disagree or poses are not (n, 7).
    """
    if len(original_frames) != len(original_poses):
        raise ValueError("原始帧数量与姿态数量不匹配")
    if original_poses.shape[1] != 7:
        raise ValueError(f"原始姿态应为(n,7)格式,实际为{original_poses.shape}")

    rotations = R.from_quat(original_poses[:, 3:7])
    # Build one Slerp over the whole trajectory instead of re-creating a
    # two-key Slerp object for every target frame (the previous hot-loop cost).
    # A single keyframe never reaches the interpolation branch, so skip it.
    slerp = Slerp(original_frames, rotations) if len(original_frames) > 1 else None
    first, last = original_frames[0], original_frames[-1]

    target_poses = []
    for t in target_frames:
        # Clamp out-of-range targets to the boundary poses.
        if t <= first:
            target_poses.append(original_poses[0])
            continue
        if t >= last:
            target_poses.append(original_poses[-1])
            continue

        idx = np.searchsorted(original_frames, t, side='left')
        t_prev, t_next = original_frames[idx - 1], original_frames[idx]
        alpha = (t - t_prev) / (t_next - t_prev)

        # Linear interpolation of the translation component.
        trans_prev = original_poses[idx - 1][:3]
        trans_next = original_poses[idx][:3]
        translation = trans_prev + alpha * (trans_next - trans_prev)

        # Spherical interpolation of the rotation component.
        rotation = slerp(t)

        target_poses.append(np.concatenate([translation, rotation.as_quat()]))

    return np.array(target_poses)
prompt_emb = 0 + + metadata = pd.read_csv('/share_zhuyixuan05/public_datasets/SpatialVID-HQ/data/train/SpatialVID_HQ_metadata.csv') + + + os.makedirs(output_dir,exist_ok=True) + chunk_size = 300 + required_keys = ["latents", "cam_emb", "prompt_emb"] + + for i, scene_name in enumerate(os.listdir(scenes_path)): + # print('index-----:',type(i)) + if i < 3 :#or i >=2000: + # # print('index-----:',i) + continue + # print('index:',i) + print('group:',i) + scene_dir = os.path.join(scenes_path, scene_name) + + # save_dir = os.path.join(output_dir,scene_name.split('.')[0]) + print('in:',scene_dir) + # print('out:',save_dir) + for j, video_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))): + + # if j < 1000 :#or i >=2000: + # print('index:',j) + # continue + print(video_name) + video_path = os.path.join(scene_dir, video_name) + if not video_path.endswith(".mp4"):# or os.path.isdir(output_dir): + continue + + video_info = metadata[metadata['id'] == video_name[:-4]] + num_frames = video_info['num frames'].iloc[0] + + scene_cam_dir = video_path.replace( "videos","annotations")[:-4] + scene_cam_path = os.path.join(scene_cam_dir,'poses.npy') + + scene_caption_path = os.path.join(scene_cam_dir,'caption.json') + + with open(scene_caption_path, 'r', encoding='utf-8') as f: + caption_data = json.load(f) + caption = caption_data["SceneSummary"] + if not os.path.exists(scene_cam_path): + print(f"Pose not found: {scene_cam_path}") + continue + + camera_poses = np.load(scene_cam_path) + cam_data_len = camera_poses.shape[0] + + # cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data} + # with open(scene_cam_path, 'rb') as f: + # cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用 + + # 加载视频 + # video_path = scene_dir + if not os.path.exists(video_path): + print(f"Video not found: {video_path}") + continue + + start_str = f"{0:07d}" + end_str = f"{chunk_size:07d}" + chunk_name = f"{video_name[:-4]}_{start_str}_{end_str}" + 
first_save_chunk_dir = os.path.join(output_dir,chunk_name) + + first_chunk_encoded_path = os.path.join(first_save_chunk_dir, "encoded_video.pth") + # print(first_chunk_encoded_path) + if os.path.exists(first_chunk_encoded_path): + data = torch.load(first_chunk_encoded_path,weights_only=False) + if 'latents' in data: + video_frames = 1 + else: + video_frames = encoder.load_video_frames(video_path) + if video_frames is None: + print(f"Failed to load video: {video_path}") + continue + print('video shape:',video_frames.shape) + + + + video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16) + print('video shape:',video_frames.shape) + + video_name = video_name[:-4].split('_')[0] + start_frame = 0 + end_frame = num_frames + # print("num_frames:",num_frames) + + cam_interval = end_frame // (cam_data_len - 1) + + cam_frames = np.linspace(start_frame, end_frame, cam_data_len, endpoint=True) + cam_frames = np.round(cam_frames).astype(int) + cam_frames = cam_frames.tolist() + # list(range(0, end_frame + 1 , cam_interval)) + + + sampled_range = range(start_frame, end_frame , chunk_size) + sampled_frames = list(sampled_range) + + sampled_chunk_end = sampled_frames[0] + chunk_size + start_str = f"{sampled_frames[0]:07d}" + end_str = f"{sampled_chunk_end:07d}" + + chunk_name = f"{video_name}_{start_str}_{end_str}" + # save_chunk_path = os.path.join(output_dir,chunk_name,"encoded_video.pth") + + # if os.path.exists(save_chunk_path): + # print(f"Video {video_name} already encoded, skipping...") + # continue + + + + + + # print(sampled_frames) + + print(f"Encoding scene {video_name}...") + chunk_count_in_one_video = 0 + for sampled_chunk_start in sampled_frames: + if num_frames - sampled_chunk_start < 100: + continue + sampled_chunk_end = sampled_chunk_start + chunk_size + start_str = f"{sampled_chunk_start:07d}" + end_str = f"{sampled_chunk_end:07d}" + + resample_cam_frame = list(range(sampled_chunk_start, sampled_chunk_end , 4)) + + # 生成保存目录名(假设video_name已定义) + 
chunk_name = f"{video_name}_{start_str}_{end_str}" + save_chunk_dir = os.path.join(output_dir,chunk_name) + + os.makedirs(save_chunk_dir,exist_ok=True) + print(f"Encoding chunk {chunk_name}...") + + encoded_path = os.path.join(save_chunk_dir, "encoded_video.pth") + + missing_keys = required_keys + if os.path.exists(encoded_path): + print('error:',encoded_path) + data = torch.load(encoded_path,weights_only=False) + missing_keys = [key for key in required_keys if key not in data] + # print(missing_keys) + # print(f"Chunk {chunk_name} already encoded, skipping...") + if missing_keys: + print(f"警告: 文件中缺少以下必要元素: {missing_keys}") + if len(missing_keys) == 0 : + continue + else: + print(f"警告: 缺少pth文件: {encoded_path}") + if not isinstance(video_frames, torch.Tensor): + + video_frames = encoder.load_video_frames(video_path) + if video_frames is None: + print(f"Failed to load video: {video_path}") + continue + + video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16) + + print('video shape:',video_frames.shape) + if "latents" in missing_keys: + chunk_frames = video_frames[:,:, sampled_chunk_start - start_frame : sampled_chunk_end - start_frame,...] 
+ + # print('extrinsic:',cam_emb['extrinsic'].shape) + + # chunk_cam_emb ={'extrinsic':cam_emb['extrinsic'][sampled_chunk_start - start_frame : sampled_chunk_end - start_frame], + # 'intrinsic':cam_emb['intrinsic']} + + # print('chunk shape:',chunk_frames.shape) + + with torch.no_grad(): + latents = encoder.pipe.encode_video(chunk_frames, **encoder.tiler_kwargs)[0] + else: + latents = data['latents'] + if "cam_emb" in missing_keys: + cam_emb = interpolate_camera_poses(cam_frames, camera_poses,resample_cam_frame) + chunk_cam_emb ={'extrinsic':cam_emb} + print(f"视频长度:{chunk_size},重采样相机长度:{cam_emb.shape[0]}") + else: + chunk_cam_emb = data['cam_emb'] + + if "prompt_emb" in missing_keys: + # 编码文本 + if chunk_count_in_one_video == 0: + print(caption) + with torch.no_grad(): + prompt_emb = encoder.pipe.encode_prompt(caption) + else: + prompt_emb = data['prompt_emb'] + + # del encoder.pipe.prompter + # pdb.set_trace() + # 保存编码结果 + encoded_data = { + "latents": latents.cpu(), + "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()}, + "cam_emb": chunk_cam_emb + } + # pdb.set_trace() + torch.save(encoded_data, encoded_path) + print(f"Saved encoded data: {encoded_path}") + processed_chunk_count += 1 + chunk_count_in_one_video += 1 + + processed_count += 1 + + print("Encoded scene numebr:",processed_count) + print("Encoded chunk numebr:",processed_chunk_count) + + # os.makedirs(save_dir,exist_ok=True) + # # 检查是否已编码 + # encoded_path = os.path.join(save_dir, "encoded_video.pth") + # if os.path.exists(encoded_path): + # print(f"Scene {scene_name} already encoded, skipping...") + # continue + + # 加载场景信息 + + + + # try: + # print(f"Encoding scene {scene_name}...") + + # 加载和编码视频 + + # 编码视频 + # with torch.no_grad(): + # latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0] + + # # 编码文本 + # if processed_count == 0: + # print('encode prompt!!!') + # prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a 
pedestrian's front camera while walking") + # del encoder.pipe.prompter + # # pdb.set_trace() + # # 保存编码结果 + # encoded_data = { + # "latents": latents.cpu(), + # #"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()}, + # "cam_emb": cam_emb + # } + # # pdb.set_trace() + # torch.save(encoded_data, encoded_path) + # print(f"Saved encoded data: {encoded_path}") + # processed_count += 1 + + # except Exception as e: + # print(f"Error encoding scene {scene_name}: {e}") + # continue + + print(f"Encoding completed! Processed {processed_count} scenes.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/public_datasets/SpatialVID-HQ/SpatialVid/HQ/videos/") + parser.add_argument("--text_encoder_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth") + parser.add_argument("--vae_path", type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth") + + parser.add_argument("--output_dir",type=str, + default="/share_zhuyixuan05/zhuyixuan05/spatialvid") + + args = parser.parse_args() + encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir) diff --git a/scripts/encode_spatialvid_first_frame.py b/scripts/encode_spatialvid_first_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..00413b06493c7d349243d18b80333117252fd392 --- /dev/null +++ b/scripts/encode_spatialvid_first_frame.py @@ -0,0 +1,285 @@ + +import os +import torch +import lightning as pl +from PIL import Image +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import json +import imageio +from torchvision.transforms import v2 +from einops import rearrange +import argparse +import numpy as np +import pdb +from tqdm import tqdm +import pandas as pd + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +from scipy.spatial.transform import Slerp +from scipy.spatial.transform import 
def interpolate_camera_poses(original_frames, original_poses, target_frames):
    """Resample camera poses ([tx, ty, tz, qx, qy, qz, qw] rows) at target frames.

    Translation components are linearly interpolated and rotations are
    spherically interpolated (SLERP) between the two bracketing key frames;
    targets before the first / after the last key frame reuse the boundary
    pose unchanged. Returns an (m, 7) numpy array, one row per target frame.

    Raises ValueError when frame/pose counts disagree or poses are not (n, 7).
    """
    print('original_frames:', len(original_frames))
    print('original_poses:', len(original_poses))
    if len(original_frames) != len(original_poses):
        raise ValueError("原始帧数量与姿态数量不匹配")
    if original_poses.shape[1] != 7:
        raise ValueError(f"原始姿态应为(n,7)格式,实际为{original_poses.shape}")

    # Rotation objects for all key frames, built once up front.
    key_rotations = R.from_quat(original_poses[:, 3:7])

    def _pose_at(t):
        # Locate the key-frame interval bracketing t.
        pos = np.searchsorted(original_frames, t, side='left')
        if pos == 0:
            return original_poses[0]
        if pos >= len(original_frames):
            return original_poses[-1]

        frame_lo, frame_hi = original_frames[pos - 1], original_frames[pos]
        lo, hi = original_poses[pos - 1], original_poses[pos]
        weight = (t - frame_lo) / (frame_hi - frame_lo)

        # Lerp the translation, slerp the rotation over the bracketing pair.
        shift = lo[:3] + weight * (hi[:3] - lo[:3])
        spin = Slerp([frame_lo, frame_hi], key_rotations[pos - 1:pos + 1])(t)
        return np.concatenate([shift, spin.as_quat()])

    return np.array([_pose_at(t) for t in target_frames])
+ + reader.close() + + if len(frames) == 0: + return None + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + return frames + +def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir): + """编码所有场景的视频""" + + encoder = VideoEncoder(text_encoder_path, vae_path) + encoder = encoder.cuda() + encoder.pipe.device = "cuda" + + processed_count = 0 + processed_chunk_count = 0 + + metadata = pd.read_csv('/share_zhuyixuan05/public_datasets/SpatialVID-HQ/data/train/SpatialVID_HQ_metadata.csv') + + os.makedirs(output_dir,exist_ok=True) + chunk_size = 300 + + for i, scene_name in enumerate(os.listdir(scenes_path)): + if i < 2: + continue + print('group:',i) + scene_dir = os.path.join(scenes_path, scene_name) + + print('in:',scene_dir) + for j, video_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))): + print(video_name) + video_path = os.path.join(scene_dir, video_name) + if not video_path.endswith(".mp4"): + continue + + video_info = metadata[metadata['id'] == video_name[:-4]] + num_frames = video_info['num frames'].iloc[0] + + scene_cam_dir = video_path.replace("videos","annotations")[:-4] + scene_cam_path = os.path.join(scene_cam_dir,'poses.npy') + scene_caption_path = os.path.join(scene_cam_dir,'caption.json') + + with open(scene_caption_path, 'r', encoding='utf-8') as f: + caption_data = json.load(f) + caption = caption_data["SceneSummary"] + + if not os.path.exists(scene_cam_path): + print(f"Pose not found: {scene_cam_path}") + continue + + camera_poses = np.load(scene_cam_path) + cam_data_len = camera_poses.shape[0] + + if not os.path.exists(video_path): + print(f"Video not found: {video_path}") + continue + + video_name = video_name[:-4].split('_')[0] + start_frame = 0 + end_frame = num_frames + + cam_interval = end_frame // (cam_data_len - 1) + + cam_frames = np.linspace(start_frame, end_frame, cam_data_len, endpoint=True) + cam_frames = np.round(cam_frames).astype(int) + cam_frames = 
cam_frames.tolist() + + sampled_range = range(start_frame, end_frame, chunk_size) + sampled_frames = list(sampled_range) + + print(f"Encoding scene {video_name}...") + chunk_count_in_one_video = 0 + + for sampled_chunk_start in sampled_frames: + if num_frames - sampled_chunk_start < 100: + continue + + sampled_chunk_end = sampled_chunk_start + chunk_size + start_str = f"{sampled_chunk_start:07d}" + end_str = f"{sampled_chunk_end:07d}" + + chunk_name = f"{video_name}_{start_str}_{end_str}" + save_chunk_dir = os.path.join(output_dir, chunk_name) + os.makedirs(save_chunk_dir, exist_ok=True) + + print(f"Encoding chunk {chunk_name}...") + + first_latent_path = os.path.join(save_chunk_dir, "first_latent.pth") + + if os.path.exists(first_latent_path): + print(f"First latent for chunk {chunk_name} already exists, skipping...") + continue + + # 只加载需要的那一帧 + first_frame_idx = sampled_chunk_start + print(f"first_frame:{first_frame_idx}") + first_frame = encoder.load_single_frame(video_path, first_frame_idx) + + if first_frame is None: + print(f"Failed to load frame {first_frame_idx} from: {video_path}") + continue + + first_frame = first_frame.to("cuda", dtype=torch.bfloat16) + + # 重复4次 + repeated_first_frame = first_frame.repeat(1, 1, 4, 1, 1) + print(f"Repeated first frame shape: {repeated_first_frame.shape}") + + with torch.no_grad(): + first_latents = encoder.pipe.encode_video(repeated_first_frame, **encoder.tiler_kwargs)[0] + + first_latent_data = { + "latents": first_latents.cpu(), + } + torch.save(first_latent_data, first_latent_path) + print(f"Saved first latent: {first_latent_path}") + + processed_chunk_count += 1 + chunk_count_in_one_video += 1 + + processed_count += 1 + print("Encoded scene number:", processed_count) + print("Encoded chunk number:", processed_chunk_count) + + print(f"Encoding completed! 
from PIL import Image, ImageDraw, ImageFont
import os

os.makedirs("wasd_ui", exist_ok=True)

# UI sizes (small)
key_size = (48, 48)
corner = 10
bg_padding = 6
try:
    font = ImageFont.truetype("arial.ttf", 28)  # replace with a locally available font
except OSError:
    # arial.ttf is absent on most Linux systems; fall back to Pillow's
    # built-in bitmap font so the script still produces assets.
    font = ImageFont.load_default()

def rounded_rect(im, bbox, radius, fill):
    """Draw a filled rounded rectangle onto *im* (RGBA-aware)."""
    draw = ImageDraw.Draw(im, "RGBA")
    draw.rounded_rectangle(bbox, radius=radius, fill=fill)

# background plate
bg_width = key_size[0] * 3 + bg_padding * 4
bg_height = key_size[1] * 2 + bg_padding * 4
ui_bg = Image.new("RGBA", (bg_width, bg_height), (0, 0, 0, 0))
rounded_rect(ui_bg, (0, 0, bg_width, bg_height), corner, (0, 0, 0, 140))
ui_bg.save("wasd_ui/ui_background.png")

keys = ["W", "A", "S", "D"]

def draw_key(char, active):
    """Render one key cap; active keys are brighter with darker glyphs."""
    im = Image.new("RGBA", key_size, (0, 0, 0, 0))
    rounded_rect(im, (0, 0, key_size[0], key_size[1]), corner,
                 (255, 255, 255, 230) if active else (200, 200, 200, 180))
    draw = ImageDraw.Draw(im)
    color = (0, 0, 0) if active else (50, 50, 50)
    # ImageDraw.textsize was removed in Pillow 10.0; measure with textbbox.
    left, top, right, bottom = draw.textbbox((0, 0), char, font=font)
    w, h = right - left, bottom - top
    draw.text(((key_size[0] - w) // 2, (key_size[1] - h) // 2),
              char, font=font, fill=color)
    return im

for k in keys:
    draw_key(k, False).save(f"wasd_ui/key_{k}_idle.png")
    draw_key(k, True).save(f"wasd_ui/key_{k}_active.png")

print("✅ WASD UI assets generated in ./wasd_ui/")
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A.

    Args:
        pose_a: 4x4 extrinsic matrix of camera A (numpy array or torch tensor).
        pose_b: 4x4 extrinsic matrix of camera B.
        use_torch: if True compute with torch (returns a tensor), else numpy.

    Returns:
        The 4x4 relative pose matrix ``pose_b @ inverse(pose_a)``.

    Raises:
        ValueError: if either input is not shaped (4, 4).
    """
    # Validate with explicit raises: bare `assert` is stripped under `python -O`.
    if pose_a.shape != (4, 4):
        raise ValueError(f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}")
    if pose_b.shape != (4, 4):
        raise ValueError(f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}")

    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
model_names, model_classes, model_resource = config + + if 'wan_video_dit' in model_names: + new_model_names = [] + new_model_classes = [] + + for name, cls in zip(model_names, model_classes): + if name == 'wan_video_dit': + new_model_names.append(name) + new_model_classes.append(WanModelMoe) + print(f"✅ 替换了模型类: {name} -> WanModelMoe") + else: + new_model_names.append(name) + new_model_classes.append(cls) + + model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) + + +def add_framepack_components(dit_model): + """添加FramePack相关组件""" + if not hasattr(dit_model, 'clean_x_embedder'): + inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + if scale == "1x": + x = x.to(self.proj.weight.dtype) + return self.proj(x) + elif scale == "2x": + x = x.to(self.proj_2x.weight.dtype) + return self.proj_2x(x) + elif scale == "4x": + x = x.to(self.proj_4x.weight.dtype) + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + dit_model.clean_x_embedder = CleanXEmbedder(inner_dim) + model_dtype = next(dit_model.parameters()).dtype + dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype) + print("✅ 添加了FramePack的clean_x_embedder组件") + + +def add_moe_components(dit_model, moe_config): + """🔧 添加MoE相关组件 - 修正版本""" + if not hasattr(dit_model, 'moe_config'): + dit_model.moe_config = moe_config + print("✅ 添加了MoE配置到模型") + dit_model.top_k = moe_config.get("top_k", 1) + + # 为每个block动态添加MoE组件 + dim = dit_model.blocks[0].self_attn.q.weight.shape[0] + unified_dim = moe_config.get("unified_dim", 25) + num_experts = 
def generate_sekai_camera_embeddings_sliding(
        cam_data,
        start_frame,
        initial_condition_frames,
        new_frames,
        total_generated,
        use_real_poses=True,
        direction="left"):
    """Build camera embeddings for the Sekai modality (sliding-window generation).

    Args:
        cam_data: dict with key 'extrinsic' -> (N, 4, 4) array of camera extrinsics, or None.
        start_frame: generation start index.
        initial_condition_frames: number of initial condition frames.
        new_frames: number of new frames generated this step.
        total_generated: unused; kept for interface compatibility.
        use_real_poses: use real Sekai poses when cam_data provides them.
        direction: synthetic motion pattern ("left", "right", "forward_left",
            "forward_right", "s_curve", "left_right").

    Returns:
        torch.bfloat16 tensor of shape (M, 13): flattened 3x4 relative pose (12)
        plus a condition-mask column (1).
    """
    time_compression_ratio = 4
    # FramePack consumes: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new_frames targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    max_needed_frames = max(
        start_frame + initial_condition_frames + new_frames,
        framepack_needed_frames,
        30,
    )

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + initial_condition_frames + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Map embedding index i to the original (uncompressed) frame index.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio
            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Out of range: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        # NOTE(review): real-pose branch masks [start, start+cond) while the
        # synthetic branch masks [start, start+cond+1) — confirm the asymmetry.
        condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
        label = "真实"
    else:
        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")

        cond = initial_condition_frames
        stage_1 = new_frames // 2
        stage_2 = new_frames - stage_1
        active_end = cond + stage_1 + stage_2

        banners = {
            "left": "--------------- LEFT TURNING MODE ---------------",
            "right": "--------------- RIGHT TURNING MODE ---------------",
            "forward_left": "--------------- FORWARD LEFT MODE ---------------",
            "forward_right": "--------------- FORWARD RIGHT MODE ---------------",
            "s_curve": "--------------- S CURVE MODE ---------------",
            "left_right": "--------------- LEFT RIGHT MODE ---------------",
        }
        if direction not in banners:
            raise ValueError(f"未定义的相机运动方向: {direction}")
        print(banners[direction])

        def _pose(yaw=0.0, forward=0.0, lateral=0.0):
            # Relative camera pose: yaw rotation about Y, translation along -Z
            # (forward) and along X (lateral drift).
            p = np.eye(4, dtype=np.float32)
            c, s = np.cos(yaw), np.sin(yaw)
            p[0, 0] = c
            p[0, 2] = s
            p[2, 0] = -s
            p[2, 2] = c
            p[2, 3] = -forward
            p[0, 3] = lateral
            return p

        def _motion(i):
            """Return (yaw, forward, lateral) for frame i under `direction`."""
            if i < cond or i >= active_end:
                # Condition frames and padding beyond the targets stay static.
                return 0.0, 0.0, 0.0
            if direction == "left":
                return 0.03, 0.0, 0.0
            if direction == "right":
                return -0.03, 0.0, 0.0
            if direction == "forward_left":
                return 0.03, 0.03, 0.0
            if direction == "forward_right":
                return -0.03, 0.03, 0.0
            if direction == "s_curve":
                if i < cond + stage_1:
                    return 0.03, 0.03, 0.0
                # Second stage turns right; slight leftward drift early on keeps inertia.
                lateral = -0.01 if i < cond + stage_1 + stage_2 // 3 else 0.0
                return -0.03, 0.03, lateral
            # left_right: turn left for stage 1, right for stage 2.
            return (0.03, 0.0, 0.0) if i < cond + stage_1 else (-0.03, 0.0, 0.0)

        relative_poses = [
            torch.as_tensor(_pose(*_motion(i))[:3, :]) for i in range(max_needed_frames)
        ]
        condition_end = min(start_frame + initial_condition_frames + 1, max_needed_frames)
        label = "合成"

    # Flatten each 3x4 pose to 12 values and append the condition mask column.
    pose_embedding = torch.stack(relative_poses, dim=0).flatten(1)
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 Sekai{label}camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
def generate_openx_camera_embeddings_sliding(
        encoded_data, start_frame, initial_condition_frames, new_frames, use_real_poses):
    """Build camera embeddings for the OpenX modality (sliding-window generation).

    Returns a (M, 13) bfloat16 tensor: flattened 3x4 relative pose plus a
    condition-mask column.
    """
    time_compression_ratio = 4
    # FramePack consumes: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new_frames targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    max_needed_frames = max(
        start_frame + initial_condition_frames + new_frames,
        framepack_needed_frames,
        30,
    )

    have_real = (
        use_real_poses
        and encoded_data is not None
        and 'cam_emb' in encoded_data
        and 'extrinsic' in encoded_data['cam_emb']
    )

    if have_real:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + initial_condition_frames + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # 4x temporal stride, like sekai, but OpenX sequences are shorter.
            lo = i * time_compression_ratio
            hi = lo + time_compression_ratio
            if hi < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[lo], cam_extrinsic[hi])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                print(f"⚠️ 帧{lo}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
        label = "真实"
    else:
        print("🔧 使用OpenX合成camera数据")
        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

        # Synthetic robot-manipulation motion: small constant roll/pitch/yaw
        # plus a slow translation, identical for every frame.
        roll, pitch, yaw = 0.02, 0.01, 0.015
        forward_speed = 0.003

        cr, sr = np.cos(roll), np.sin(roll)
        cp, sp = np.cos(pitch), np.sin(pitch)
        cy, sy = np.cos(yaw), np.sin(yaw)

        base = np.eye(4, dtype=np.float32)
        # Simplified composite rotation (ZYX order).
        base[0, 0] = cy * cp
        base[0, 1] = cy * sp * sr - sy * cr
        base[0, 2] = cy * sp * cr + sy * sr
        base[1, 0] = sy * cp
        base[1, 1] = sy * sp * sr + cy * cr
        base[1, 2] = sy * sp * cr - cy * sr
        base[2, 0] = -sp
        base[2, 1] = cp * sr
        base[2, 2] = cp * cr
        # Fine translation: mostly along depth, slight X/Y drift.
        base[0, 3] = forward_speed * 0.5
        base[1, 3] = forward_speed * 0.3
        base[2, 3] = -forward_speed

        relative_poses = [torch.as_tensor(base[:3, :]) for _ in range(max_needed_frames)]
        label = "合成"

    pose_embedding = torch.stack(relative_poses, dim=0).flatten(1)

    # Condition mask: frames [start, start+cond) are condition, rest are targets.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    mask[start_frame:min(start_frame + initial_condition_frames, max_needed_frames)] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 OpenX{label}camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
def generate_nuscenes_camera_embeddings_sliding(
        scene_info, start_frame, initial_condition_frames, new_frames):
    """Build camera embeddings for the NuScenes modality (sliding-window generation).

    Returns a (M, 8) bfloat16 tensor: 7-D pose (translation + quaternion) plus a
    condition-mask column. Kept consistent with train_moe.py.
    """
    time_compression_ratio = 4
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    max_needed_frames = max(framepack_needed_frames, 30)

    def _assemble(pose_sequence, label):
        # Append the condition-mask column and cast to bf16.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + initial_condition_frames, max_needed_frames)] = 1.0
        emb = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes{label}pose embedding shape: {emb.shape}")
        return emb.to(torch.bfloat16)

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if not keyframe_poses:
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            return _assemble(torch.zeros(max_needed_frames, 7, dtype=torch.float32), "零")

        # First keyframe is the reference for relative translation.
        reference_pose = keyframe_poses[0]
        # Identity-quaternion pose used past the end of the keyframe list.
        zero_pose = torch.cat([
            torch.zeros(3, dtype=torch.float32),
            torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32),
        ], dim=0)

        pose_vecs = []
        for i in range(max_needed_frames):
            if i < len(keyframe_poses):
                cur = keyframe_poses[i]
                translation = torch.tensor(
                    np.array(cur['translation']) - np.array(reference_pose['translation']),
                    dtype=torch.float32,
                )
                # Simplified: quaternion used as-is (not relative to reference).
                rotation = torch.tensor(cur['rotation'], dtype=torch.float32)
                pose_vecs.append(torch.cat([translation, rotation], dim=0))  # 7D
            else:
                pose_vecs.append(zero_pose)
        return _assemble(torch.stack(pose_vecs, dim=0), "真实")

    print("🔧 使用NuScenes合成pose数据")
    pose_vecs = []
    for i in range(max_needed_frames):
        # Arc trajectory approximating a gentle urban left turn.
        angle = i * 0.04
        radius = 15.0  # large turning radius, car-like
        translation = torch.tensor(
            [radius * np.sin(angle), 0.0, radius * (1 - np.cos(angle))],
            dtype=torch.float32,
        )
        # Heading follows the arc tangent; quaternion for a rotation about Y.
        yaw = angle + np.pi / 2
        rotation = torch.tensor(
            [np.cos(yaw / 2), 0.0, 0.0, np.sin(yaw / 2)], dtype=torch.float32
        )
        pose_vecs.append(torch.cat([translation, rotation], dim=0))  # [tx,ty,tz,qw,qx,qy,qz]
    return _assemble(torch.stack(pose_vecs, dim=0), "合成左转")
def prepare_framepack_sliding_window_with_camera_moe(
        history_latents,
        target_frames_to_generate,
        camera_embedding_full,
        start_frame,
        modality_type,
        max_history_frames=49):
    """Assemble FramePack sliding-window inputs (MoE variant).

    Args:
        history_latents: [C, T, H, W] accumulated history latents.
        target_frames_to_generate: number of new latent frames this step.
        camera_embedding_full: full per-frame camera embedding sequence.
        start_frame: generation start index (unused in the body; kept for API).
        modality_type: modality tag forwarded to the output dict.
        max_history_frames: unused here; kept for interface compatibility.

    Returns:
        dict with latent tensors, index tensors, the combined camera embedding,
        the modality tag and the current/next history lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + targets.
    total_len = 1 + 16 + 2 + 1 + target_frames_to_generate
    idx = torch.arange(0, total_len)
    start_idx, idx_4x, idx_2x, idx_1x, latent_indices = idx.split(
        [1, 16, 2, 1, target_frames_to_generate], dim=0)
    clean_latent_indices = torch.cat([start_idx, idx_1x], dim=0)

    # Zero-pad the camera sequence if it is shorter than the index layout.
    if camera_embedding_full.shape[0] < total_len:
        print(f"⚠️ camera_embedding长度不足,进行零补齐: 当前长度 {camera_embedding_full.shape[0]}, 需要长度 {total_len}")
        pad = torch.zeros(
            total_len - camera_embedding_full.shape[0],
            camera_embedding_full.shape[1],
            dtype=camera_embedding_full.dtype,
            device=camera_embedding_full.device,
        )
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    combined_camera = torch.zeros(
        total_len,
        camera_embedding_full.shape[1],
        dtype=camera_embedding_full.dtype,
        device=camera_embedding_full.device,
    )

    # Last up-to-19 history poses, right-aligned into slots [..:19].
    hist_slice = camera_embedding_full[max(T - 19, 0):T, :].clone()
    combined_camera[19 - hist_slice.shape[0]:19, :] = hist_slice

    # Target poses written starting at slot 19.
    # NOTE(review): targets occupy camera slots [19:19+n] while latent_indices
    # places target frames at index positions [20:20+n] — confirm this one-slot
    # offset is intended by the consumer.
    tgt_slice = camera_embedding_full[T:T + target_frames_to_generate, :].clone()
    combined_camera[19:19 + tgt_slice.shape[0], :] = tgt_slice

    # Rebuild the condition mask from the actual history length:
    # everything defaults to target (0); valid clean-latent slots become 1.
    combined_camera[:, -1] = 0.0
    if T > 0:
        available_frames = min(T, 19)
        combined_camera[19 - available_frames:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Clean latents: last up-to-19 history frames, right-aligned.
    clean_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        available_frames = min(T, 19)
        clean_combined[:, 19 - available_frames:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_combined[:, 0:16, :, :]
    clean_latents_2x = clean_combined[:, 16:18, :, :]
    clean_latents_1x = clean_combined[:, 18:19, :, :]

    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    return {
        'latent_indices': latent_indices,
        'clean_latents': torch.cat([start_latent, clean_latents_1x], dim=1),
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,
        'current_length': T,
        'next_length': T + target_frames_to_generate,
    }
def overlay_controls(frame_img, pose_vec, icons):
    """Overlay WASD / arrow-key control icons on a frame based on camera pose.

    Args:
        frame_img: PIL image to draw on (modified in place and returned).
        pose_vec: 12 elements (flattened 3x4 pose matrix) + mask column.
        icons: dict mapping icon filename -> RGBA icon image.

    Returns:
        frame_img, unchanged when the pose is missing or all-zero.
    """
    # Missing or static pose: nothing to draw.
    if pose_vec is None or np.all(pose_vec[:12] == 0):
        return frame_img

    # Flattened 3x4 layout: [r00 r01 r02 tx  r10 r11 r12 ty  r20 r21 r22 tz]
    tx = pose_vec[3]
    tz = pose_vec[11]
    # Yaw about Y: sin = r02, cos = r00. Pitch about X: sin = -r12, cos = r22.
    yaw = np.arctan2(pose_vec[2], pose_vec[0])
    pitch = np.arctan2(-pose_vec[6], pose_vec[10])

    # Activation thresholds.
    TRANS_THRESH = 0.01
    ROT_THRESH = 0.005

    # Translation keys (WASD); convention: -Z is forward, +X is right.
    # Rotation arrows: yaw + is left, - is right; pitch + is down, - is up.
    active = {
        'move_forward': tz < -TRANS_THRESH,
        'move_backward': tz > TRANS_THRESH,
        'move_left': tx < -TRANS_THRESH,
        'move_right': tx > TRANS_THRESH,
        'turn_left': yaw > ROT_THRESH,
        'turn_right': yaw < -ROT_THRESH,
        'turn_up': pitch < -ROT_THRESH,
        'turn_down': pitch > ROT_THRESH,
    }

    W, H = frame_img.size
    spacing = 60

    def stamp(base, x, y):
        name = f"{base}.png" if active[base] else f"not_{base}.png"
        icon = icons.get(name)
        if icon is not None:
            # Paste using the icon's own alpha channel as the mask.
            frame_img.paste(icon, (int(x), int(y)), icon)

    # WASD cluster (bottom-left).
    wasd_x = 100
    base_y = H - 100
    stamp('move_forward', wasd_x, base_y - spacing)   # W
    stamp('move_left', wasd_x - spacing, base_y)      # A
    stamp('move_backward', wasd_x, base_y)            # S
    stamp('move_right', wasd_x + spacing, base_y)     # D

    # Arrow cluster (bottom-right).
    arrow_x = W - 150
    stamp('turn_up', arrow_x, base_y - spacing)       # up
    stamp('turn_left', arrow_x - spacing, base_y)     # left
    stamp('turn_down', arrow_x, base_y)               # down
    stamp('turn_right', arrow_x + spacing, base_y)    # right

    return frame_img
模型初始化 + replace_dit_model_in_manager() + + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([ + "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ]) + pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda") + + # 2. 添加传统camera编码器(兼容性) + dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in pipe.dit.blocks: + block.cam_encoder = nn.Linear(13, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + # 3. 添加FramePack组件 + add_framepack_components(pipe.dit) + + # 4. 添加MoE组件 + moe_config = { + "num_experts": moe_num_experts, + "top_k": moe_top_k, + "hidden_dim": moe_hidden_dim or dim * 2, + "sekai_input_dim": 13, # Sekai: 12维pose + 1维mask + "nuscenes_input_dim": 8, # NuScenes: 7维pose + 1维mask + "openx_input_dim": 13 # OpenX: 12维pose + 1维mask (类似sekai) + } + add_moe_components(pipe.dit, moe_config) + + # 5. 加载训练好的权重 + dit_state_dict = torch.load(dit_path, map_location="cpu") + pipe.dit.load_state_dict(dit_state_dict, strict=False) # 使用strict=False以兼容新增的MoE组件 + pipe = pipe.to(device) + model_dtype = next(pipe.dit.parameters()).dtype + + if hasattr(pipe.dit, 'clean_x_embedder'): + pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype) + + # 设置去噪步数 + pipe.scheduler.set_timesteps(50) + + # 6. 
加载初始条件 + print("Loading initial condition frames...") + initial_latents, encoded_data = load_encoded_video_from_pth( + condition_pth_path, + start_frame=start_frame, + num_frames=initial_condition_frames + ) + + # 空间裁剪 + target_height, target_width = 60, 104 + C, T, H, W = initial_latents.shape + + if H > target_height or W > target_width: + h_start = (H - target_height) // 2 + w_start = (W - target_width) // 2 + initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width] + H, W = target_height, target_width + + history_latents = initial_latents.to(device, dtype=model_dtype) + + print(f"初始history_latents shape: {history_latents.shape}") + + # 7. 编码prompt - 支持CFG + if use_gt_prompt and 'prompt_emb' in encoded_data: + print("✅ 使用预编码的GT prompt embedding") + prompt_emb_pos = encoded_data['prompt_emb'] + # 将prompt_emb移到正确的设备和数据类型 + if 'context' in prompt_emb_pos: + prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype) + if 'context_mask' in prompt_emb_pos: + prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype) + + # 如果使用Text CFG,生成负向prompt + if text_guidance_scale > 1.0: + prompt_emb_neg = pipe.encode_prompt("") + print(f"使用Text CFG with GT prompt,guidance scale: {text_guidance_scale}") + else: + prompt_emb_neg = None + print("不使用Text CFG") + + # 🔧 打印GT prompt文本(如果有) + if 'prompt' in encoded_data['prompt_emb']: + gt_prompt_text = encoded_data['prompt_emb']['prompt'] + print(f"📝 GT Prompt文本: {gt_prompt_text}") + else: + # 使用传入的prompt参数重新编码 + print(f"🔄 重新编码prompt: {prompt}") + if text_guidance_scale > 1.0: + prompt_emb_pos = pipe.encode_prompt(prompt) + prompt_emb_neg = pipe.encode_prompt("") + print(f"使用Text CFG,guidance scale: {text_guidance_scale}") + else: + prompt_emb_pos = pipe.encode_prompt(prompt) + prompt_emb_neg = None + print("不使用Text CFG") + + # 8. 
加载场景信息(对于NuScenes) + scene_info = None + if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path): + with open(scene_info_path, 'r') as f: + scene_info = json.load(f) + print(f"加载NuScenes场景信息: {scene_info_path}") + + # 9. 预生成完整的camera embedding序列 + if modality_type == "sekai": + camera_embedding_full = generate_sekai_camera_embeddings_sliding( + encoded_data.get('cam_emb', None), + start_frame, + initial_condition_frames, + total_frames_to_generate, + 0, + use_real_poses=use_real_poses, + direction=direction + ).to(device, dtype=model_dtype) + elif modality_type == "nuscenes": + camera_embedding_full = generate_nuscenes_camera_embeddings_sliding( + scene_info, + start_frame, + initial_condition_frames, + total_frames_to_generate + ).to(device, dtype=model_dtype) + elif modality_type == "openx": + camera_embedding_full = generate_openx_camera_embeddings_sliding( + encoded_data, + start_frame, + initial_condition_frames, + total_frames_to_generate, + use_real_poses=use_real_poses + ).to(device, dtype=model_dtype) + else: + raise ValueError(f"不支持的模态类型: {modality_type}") + + print(f"完整camera序列shape: {camera_embedding_full.shape}") + + # 10. 为Camera CFG创建无条件的camera embedding + if use_camera_cfg: + camera_embedding_uncond = torch.zeros_like(camera_embedding_full) + print(f"创建无条件camera embedding用于CFG") + + # 11. 
滑动窗口生成循环 + total_generated = 0 + all_generated_frames = [] + + while total_generated < total_frames_to_generate: + current_generation = min(frames_per_generation, total_frames_to_generate - total_generated) + print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}") + print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}") + + # FramePack数据准备 - MoE版本 + framepack_data = prepare_framepack_sliding_window_with_camera_moe( + history_latents, + current_generation, + camera_embedding_full, + start_frame, + modality_type, + max_history_frames + ) + + # 准备输入 + clean_latents = framepack_data['clean_latents'].unsqueeze(0) + clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0) + clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0) + camera_embedding = framepack_data['camera_embedding'].unsqueeze(0) + + # 准备modality_inputs + modality_inputs = {modality_type: camera_embedding} + + # 为CFG准备无条件camera embedding + if use_camera_cfg: + camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0) + modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch} + + # 索引处理 + latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu() + clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu() + clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu() + clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu() + + # 初始化要生成的latents + new_latents = torch.randn( + 1, C, current_generation, H, W, + device=device, dtype=model_dtype + ) + + extra_input = pipe.prepare_extra_input(new_latents) + + print(f"Camera embedding shape: {camera_embedding.shape}") + print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}") + + # 去噪循环 - 支持CFG + timesteps = pipe.scheduler.timesteps + + for i, timestep in 
enumerate(timesteps): + if i % 10 == 0: + print(f" 去噪步骤 {i+1}/{len(timesteps)}") + + timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype) + + with torch.no_grad(): + # CFG推理 + if use_camera_cfg and camera_guidance_scale > 1.0: + # 条件预测(有camera) + noise_pred_cond, moe_loess = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, # MoE模态输入 + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + # 无条件预测(无camera) + noise_pred_uncond, moe_loess = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding_uncond_batch, + modality_inputs=modality_inputs_uncond, # MoE无条件模态输入 + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos), + **extra_input + ) + + # Camera CFG + noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # 如果同时使用Text CFG + if text_guidance_scale > 1.0 and prompt_emb_neg: + noise_pred_text_uncond, moe_loess = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_neg, + **extra_input + ) + + # 应用Text 
CFG到已经应用Camera CFG的结果 + noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond) + + elif text_guidance_scale > 1.0 and prompt_emb_neg: + # 只使用Text CFG + noise_pred_cond, moe_loess = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + noise_pred_uncond, moe_loess= pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_neg, + **extra_input + ) + + noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond) + + else: + # 标准推理(无CFG) + noise_pred, moe_loess = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, # MoE模态输入 + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents) + + # 更新历史 + new_latents_squeezed = new_latents.squeeze(0) + history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1) + + # 维护滑动窗口 + if history_latents.shape[1] > 
max_history_frames: + first_frame = history_latents[:, 0:1, :, :] + recent_frames = history_latents[:, -(max_history_frames-1):, :, :] + history_latents = torch.cat([first_frame, recent_frames], dim=1) + print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧") + + print(f"更新后history_latents shape: {history_latents.shape}") + + all_generated_frames.append(new_latents_squeezed) + total_generated += current_generation + + print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧") + + # 12. 解码和保存 + print("\n🔧 解码生成的视频...") + + all_generated = torch.cat(all_generated_frames, dim=1) + final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0) + + print(f"最终视频shape: {final_video.shape}") + + decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + + print(f"Saving video to {output_path} ...") + + video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy() + video_np = (video_np * 0.5 + 0.5).clip(0, 1) + video_np = (video_np * 255).astype(np.uint8) + + icons = {} + video_camera_poses = None + if add_icons: + # 加载用于叠加的图标资源 + icons_dir = os.path.join(ROOT_DIR, 'icons') + icon_names = ['move_forward.png', 'not_move_forward.png', + 'move_backward.png', 'not_move_backward.png', + 'move_left.png', 'not_move_left.png', + 'move_right.png', 'not_move_right.png', + 'turn_up.png', 'not_turn_up.png', + 'turn_down.png', 'not_turn_down.png', + 'turn_left.png', 'not_turn_left.png', + 'turn_right.png', 'not_turn_right.png'] + for name in icon_names: + path = os.path.join(icons_dir, name) + if os.path.exists(path): + try: + icon = Image.open(path).convert("RGBA") + # 调整图标尺寸 + icon = icon.resize((50, 50), Image.Resampling.LANCZOS) + icons[name] = icon + except Exception as e: + print(f"Error loading icon {name}: {e}") + else: + print(f"Warning: Icon {name} not found at {path}") + + # 获取与视频帧对应的相机姿态 + time_compression_ratio = 4 + camera_poses = 
camera_embedding_full.detach().float().cpu().numpy() + video_camera_poses = [x for x in camera_poses for _ in range(time_compression_ratio)] + + with imageio.get_writer(output_path, fps=20) as writer: + for i, frame in enumerate(video_np): + # Convert to PIL for overlay + img = Image.fromarray(frame) + + if add_icons and video_camera_poses is not None and icons: + # Video frame i corresponds to camera_embedding_full[start_frame + i] + pose_idx = start_frame + i + if pose_idx < len(video_camera_poses): + pose_vec = video_camera_poses[pose_idx] + img = overlay_controls(img, pose_vec, icons) + + writer.append_data(np.array(img)) + + print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}") + print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧") + print(f"使用模态: {modality_type}") + + +def main(): + parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态") + + # 基础参数 + parser.add_argument("--condition_pth", type=str, + default="../examples/condition_pth/garden_1.pth") + parser.add_argument("--start_frame", type=int, default=0) + parser.add_argument("--initial_condition_frames", type=int, default=1) + parser.add_argument("--frames_per_generation", type=int, default=8) + parser.add_argument("--total_frames_to_generate", type=int, default=24) + parser.add_argument("--max_history_frames", type=int, default=100) + parser.add_argument("--use_real_poses", default=False) + parser.add_argument("--dit_path", type=str, default=None, required=True, + help="path to the pretrained DiT MoE model checkpoint") + parser.add_argument("--output_path", type=str, + default='./examples/output_videos/output_moe_framepack_sliding.mp4') + parser.add_argument("--prompt", type=str, default=None, + help="text prompt for video generation") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--add_icons", action="store_true", default=False, + help="在生成的视频上叠加控制图标") + + # 模态类型参数 + parser.add_argument("--modality_type", type=str, 
choices=["sekai", "nuscenes", "openx"], + default="sekai", help="模态类型:sekai 或 nuscenes 或 openx") + parser.add_argument("--scene_info_path", type=str, default=None, + help="NuScenes场景信息文件路径(仅用于nuscenes模态)") + + # CFG参数 + parser.add_argument("--use_camera_cfg", default=False, + help="使用Camera CFG") + parser.add_argument("--camera_guidance_scale", type=float, default=2.0, + help="Camera guidance scale for CFG") + parser.add_argument("--text_guidance_scale", type=float, default=1.0, + help="Text guidance scale for CFG") + + # MoE参数 + parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量") + parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家") + parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度") + parser.add_argument("--direction", type=str, default="left", help="生成视频的行进轨迹方向") + parser.add_argument("--use_gt_prompt", action="store_true", default=False, + help="使用数据集中的ground truth prompt embedding") + + args = parser.parse_args() + + print(f"🔧 MoE FramePack CFG生成设置:") + print(f"模态类型: {args.modality_type}") + print(f"Camera CFG: {args.use_camera_cfg}") + if args.use_camera_cfg: + print(f"Camera guidance scale: {args.camera_guidance_scale}") + print(f"使用GT Prompt: {args.use_gt_prompt}") + print(f"Text guidance scale: {args.text_guidance_scale}") + print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}") + print(f"DiT{args.dit_path}") + + # 验证NuScenes参数 + if args.modality_type == "nuscenes" and not args.scene_info_path: + print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据") + + inference_moe_framepack_sliding_window( + condition_pth_path=args.condition_pth, + dit_path=args.dit_path, + output_path=args.output_path, + start_frame=args.start_frame, + initial_condition_frames=args.initial_condition_frames, + frames_per_generation=args.frames_per_generation, + total_frames_to_generate=args.total_frames_to_generate, + max_history_frames=args.max_history_frames, + device=args.device, + 
prompt=args.prompt, + modality_type=args.modality_type, + use_real_poses=args.use_real_poses, + scene_info_path=args.scene_info_path, + # CFG参数 + use_camera_cfg=args.use_camera_cfg, + camera_guidance_scale=args.camera_guidance_scale, + text_guidance_scale=args.text_guidance_scale, + # MoE参数 + moe_num_experts=args.moe_num_experts, + moe_top_k=args.moe_top_k, + moe_hidden_dim=args.moe_hidden_dim, + direction=args.direction, + use_gt_prompt=args.use_gt_prompt, + add_icons=args.add_icons + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/infer_moe.py b/scripts/infer_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..dc690a903da084e76c9e6fa286edb9e85dd5f04f --- /dev/null +++ b/scripts/infer_moe.py @@ -0,0 +1,1023 @@ +import os +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +import imageio +import json +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import argparse +from torchvision.transforms import v2 +from einops import rearrange +import copy +from scipy.spatial.transform import Rotation as R + + +def compute_relative_pose_matrix(pose1, pose2): + """ + 计算相邻两帧的相对位姿,返回3×4的相机矩阵 [R_rel | t_rel] + + 参数: + pose1: 第i帧的相机位姿,形状为(7,)的数组 [tx1, ty1, tz1, qx1, qy1, qz1, qw1] + pose2: 第i+1帧的相机位姿,形状为(7,)的数组 [tx2, ty2, tz2, qx2, qy2, qz2, qw2] + + 返回: + relative_matrix: 3×4的相对位姿矩阵,前3列是旋转矩阵R_rel,第4列是平移向量t_rel + """ + # 分离平移向量和四元数 + t1 = pose1[:3] # 第i帧平移 [tx1, ty1, tz1] + q1 = pose1[3:] # 第i帧四元数 [qx1, qy1, qz1, qw1] + t2 = pose2[:3] # 第i+1帧平移 + q2 = pose2[3:] # 第i+1帧四元数 + + # 1. 计算相对旋转矩阵 R_rel + rot1 = R.from_quat(q1) # 第i帧旋转 + rot2 = R.from_quat(q2) # 第i+1帧旋转 + rot_rel = rot2 * rot1.inv() # 相对旋转 = 后一帧旋转 × 前一帧旋转的逆 + R_rel = rot_rel.as_matrix() # 转换为3×3矩阵 + + # 2. 计算相对平移向量 t_rel + R1_T = rot1.as_matrix().T # 前一帧旋转矩阵的转置(等价于逆) + t_rel = R1_T @ (t2 - t1) # 相对平移 = R1^T × (t2 - t1) + + # 3. 
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Return the quaternion rotating the reference frame onto the current
    frame (NuScenes [w, x, y, z] convention).

    Computed as conjugate(reference) ⊗ current; inputs are assumed to be
    unit quaternions, so the conjugate equals the inverse.
    """
    q_cur = torch.tensor(current_rotation, dtype=torch.float32)
    q_ref = torch.tensor(reference_rotation, dtype=torch.float32)
    # Conjugate of a unit quaternion == its inverse.
    rw, rx, ry, rz = q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]
    cw, cx, cy, cz = q_cur
    # Hamilton product (ref^-1) ⊗ current.
    return torch.tensor([
        rw * cw - rx * cx - ry * cy - rz * cz,
        rw * cx + rx * cw + ry * cz - rz * cy,
        rw * cy - rx * cz + ry * cw + rz * cx,
        rw * cz + rx * cy - ry * cx + rz * cw,
    ])


def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a pre-encoded latent clip from a .pth file.

    Returns a tuple ``(condition_latents, payload)`` where condition_latents
    is the [C, T, H, W] slice covering frames
    [start_frame, start_frame + num_frames) and payload is the full dict
    stored in the file.

    Raises:
        ValueError: if the file holds fewer frames than requested.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data


def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return camera B's pose expressed relative to camera A.

    Both inputs are 4x4 extrinsic matrices; the result is
    ``pose_b @ pose_a^{-1}`` in the same backend (torch or numpy) selected
    by ``use_torch``.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        mat_a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        mat_b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(mat_b, torch.inverse(mat_a))

    mat_a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    mat_b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(mat_b, np.linalg.inv(mat_a))


def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' loader class for the MoE variant.

    Mutates diffsynth's global ``model_loader_configs`` in place so that any
    subsequent ModelManager.load_models() call instantiates WanModelMoe.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, config in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config
        if 'wan_video_dit' not in model_names:
            continue

        swapped_names = []
        swapped_classes = []
        for name, cls in zip(model_names, model_classes):
            if name == 'wan_video_dit':
                swapped_names.append(name)
                swapped_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                swapped_names.append(name)
                swapped_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, swapped_names, swapped_classes, model_resource)


def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to the DiT model.

    No-op when ``clean_x_embedder`` already exists.  The three Conv3d heads
    patchify clean context latents at 1x / 2x / 4x compression scales.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Size the embedder from the attention projection width of the first block.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch table instead of an if/elif chain; same behavior.
            heads = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in heads:
                raise ValueError(f"Unsupported scale: {scale}")
            head = heads[scale]
            return head(x.to(head.weight.dtype))

    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
def add_moe_components(dit_model, moe_config):
    """Attach the MoE routing stack to an already-built DiT model.

    Adds per-modality input processors (sekai/openx: 13-D input,
    nuscenes: 8-D), a global router over the unified embedding, and one
    MultiModalMoE head per transformer block.  The config dict is stored on
    the model the first time this runs.
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")
    dit_model.top_k = moe_config.get("top_k", 1)

    # Per-block MoE heads are sized from the attention projection width.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)

    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
    dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
    dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
    # OpenX shares sekai's 13-D layout but gets an independent processor.
    dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim)
    dit_model.global_router = nn.Linear(unified_dim, num_experts)

    for i, block in enumerate(dit_model.blocks):
        # MoE head: unified_dim in, transformer hidden dim out.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2),
        )
        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")


def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build the Sekai camera-embedding sequence for sliding-window FramePack.

    Each row is a flattened 3x4 relative pose (12 values) plus one mask value
    (1.0 = condition frame, 0.0 = target frame), giving a bfloat16 tensor of
    shape [T, 13].  Real extrinsics are used when available; otherwise a
    synthetic left-turning walking trajectory is produced.
    """
    time_compression_ratio = 4

    # FramePack index layout: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Always emit at least enough frames for the FramePack layout.
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30,
    )

    use_real = use_real_poses and cam_data is not None and 'extrinsic' in cam_data

    relative_poses = []
    if use_real:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f"  - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f"  - FramePack需求: {framepack_needed_frames}")
        print(f"  - 最终生成: {max_needed_frames}")

        for i in range(max_needed_frames):
            # Each embedding frame spans `time_compression_ratio` raw frames.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio
            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
    else:
        print("🔧 使用Sekai合成camera数据")
        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")

        yaw_per_frame = -0.1     # per-frame yaw about the Y axis
        forward_speed = 0.005    # per-frame forward travel
        radius_drift = 0.002     # slight inward drift → roughly circular path
        cos_yaw = np.cos(yaw_per_frame)
        sin_yaw = np.sin(yaw_per_frame)

        for _ in range(max_needed_frames):
            pose = np.eye(4, dtype=np.float32)
            # Rotation about the Y axis.
            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw
            # Translate forward along local -Z plus the inward drift on X.
            pose[2, 3] = -forward_speed
            pose[0, 3] = radius_drift
            relative_poses.append(torch.as_tensor(pose[:3, :]))

    # Flatten each 3x4 pose to 12 values; equivalent to
    # rearrange(x, 'b c d -> b (c d)') on a contiguous stack.
    pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

    # Mask column: condition (history) frames get 1.0, targets stay 0.0.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    kind = "真实" if use_real else "合成"
    print(f"🔧 Sekai{kind}camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Build the OpenX camera-embedding sequence for sliding-window FramePack.

    Layout matches the sekai variant: each row is a flattened 3x4 relative
    pose (12 values) plus a 1-value condition mask, returned as a bfloat16
    tensor of shape [T, 13].  When no real extrinsics are supplied, a small
    synthetic robot-arm-like motion (slight roll/pitch/yaw plus slow
    translation) is generated.
    """
    time_compression_ratio = 4

    # FramePack index layout: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30,
    )

    use_real = (use_real_poses and encoded_data is not None
                and 'cam_emb' in encoded_data
                and 'extrinsic' in encoded_data['cam_emb'])

    relative_poses = []
    if use_real:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f"  - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f"  - FramePack需求: {framepack_needed_frames}")
        print(f"  - 最终生成: {max_needed_frames}")

        for i in range(max_needed_frames):
            # 4x temporal stride, like sekai, over a typically shorter clip.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio
            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
    else:
        print("🔧 使用OpenX合成camera数据")
        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

        # Small per-frame angles/translation mimicking fine manipulator motion.
        roll_per_frame = 0.02
        pitch_per_frame = 0.01
        yaw_per_frame = 0.015
        forward_speed = 0.003

        cos_r, sin_r = np.cos(roll_per_frame), np.sin(roll_per_frame)
        cos_p, sin_p = np.cos(pitch_per_frame), np.sin(pitch_per_frame)
        cos_y, sin_y = np.cos(yaw_per_frame), np.sin(yaw_per_frame)

        for _ in range(max_needed_frames):
            pose = np.eye(4, dtype=np.float32)
            # Composite ZYX rotation (yaw ∘ pitch ∘ roll).
            pose[0, 0] = cos_y * cos_p
            pose[0, 1] = cos_y * sin_p * sin_r - sin_y * cos_r
            pose[0, 2] = cos_y * sin_p * cos_r + sin_y * sin_r
            pose[1, 0] = sin_y * cos_p
            pose[1, 1] = sin_y * sin_p * sin_r + cos_y * cos_r
            pose[1, 2] = sin_y * sin_p * cos_r - cos_y * sin_r
            pose[2, 0] = -sin_p
            pose[2, 1] = cos_p * sin_r
            pose[2, 2] = cos_p * cos_r
            # Gentle translation; depth (Z) carries the main motion.
            pose[0, 3] = forward_speed * 0.5
            pose[1, 3] = forward_speed * 0.3
            pose[2, 3] = -forward_speed
            relative_poses.append(torch.as_tensor(pose[:3, :]))

    # Flatten each 3x4 pose to 12 values (contiguous stack → plain reshape).
    pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

    # Mask column: 1.0 marks condition (history) frames.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    kind = "真实" if use_real else "合成"
    print(f"🔧 OpenX{kind}camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Build the NuScenes camera-embedding sequence ([T, 7 pose + 1 mask],
    bfloat16) for sliding-window FramePack, consistent with train_moe.py.

    Each row holds the relative translation (3 values) and relative rotation
    quaternion (4 values) to the next keyframe.  Real keyframe poses are used
    when scene_info provides them; otherwise a synthetic circular trajectory
    is generated.
    """
    time_compression_ratio = 4
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    max_needed_frames = max(framepack_needed_frames, 30)

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        # One extra index: each row is a prev → next difference.  Clamp to
        # the last available keyframe when the scene is too short.
        keyframe_indices = [
            min((start_frame + i) * time_compression_ratio, len(keyframe_poses) - 1)
            for i in range(max_needed_frames + 1)
        ]

        pose_vecs = []
        for i in range(max_needed_frames):
            pose_prev = keyframe_poses[keyframe_indices[i]]
            pose_next = keyframe_poses[keyframe_indices[i + 1]]
            translation = torch.tensor(
                np.array(pose_next['translation']) - np.array(pose_prev['translation']),
                dtype=torch.float32,
            )
            relative_rotation = calculate_relative_rotation(
                pose_next['rotation'],
                pose_prev['rotation'],
            )
            pose_vecs.append(torch.cat([translation, relative_rotation], dim=0))  # [7]

        pose_sequence = torch.stack(pose_vecs, dim=0)  # [max_needed_frames, 7]
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0
        camera_embedding = torch.cat([pose_sequence, mask], dim=1)
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用NuScenes合成pose数据")
    # Build an absolute circular trajectory first (one extra sample so every
    # frame has a successor to difference against).
    abs_translations = []
    abs_rotations = []
    for i in range(max_needed_frames + 1):
        angle = -i * 0.12
        radius = 8.0
        abs_translations.append(np.array(
            [radius * np.sin(angle), 0.0, radius * (1 - np.cos(angle))],
            dtype=np.float32,
        ))
        yaw = angle + np.pi / 2
        abs_rotations.append(np.array(
            [np.cos(yaw / 2), 0.0, 0.0, np.sin(yaw / 2)],
            dtype=np.float32,
        ))

    # Per-frame relative motion (prev → next).
    pose_vecs = []
    for i in range(max_needed_frames):
        translation = torch.tensor(abs_translations[i + 1] - abs_translations[i], dtype=torch.float32)
        # Relative quaternion = conj(prev) ⊗ next, [w, x, y, z] order.
        pw, px, py, pz = abs_rotations[i][0], -abs_rotations[i][1], -abs_rotations[i][2], -abs_rotations[i][3]
        nw, nx, ny, nz = abs_rotations[i + 1]
        relative_rotation = torch.tensor([
            pw * nw - px * nx - py * ny - pz * nz,
            pw * nx + px * nw + py * nz - pz * ny,
            pw * ny - px * nz + py * nw + pz * nx,
            pw * nz + px * ny - py * nx + pz * nw,
        ], dtype=torch.float32)
        pose_vecs.append(torch.cat([translation, relative_rotation], dim=0))  # [7]

    pose_sequence = torch.stack(pose_vecs, dim=0)
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0
    camera_embedding = torch.cat([pose_sequence, mask], dim=1)
    print(f"🔧 NuScenes合成相对pose embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)


def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """Pack history latents + camera rows into FramePack's fixed index layout
    (MoE variant).

    history_latents is [C, T, H, W].  The layout is
    [1 start | 16 x 4x-context | 2 x 2x-context | 1 x 1x-context | targets];
    the most recent (up to 19) history frames fill the context slots and the
    camera mask column is rewritten to match.  Returns a dict of tensors and
    index vectors consumed by the DiT forward pass.
    """
    C, T, H, W = history_latents.shape

    # Fixed index structure — this also fixes how many camera rows we need.
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    section_sizes = [1, 16, 2, 1, target_frames_to_generate]
    (start_indices, clean_latent_4x_indices, clean_latent_2x_indices,
     clean_latent_1x_indices, latent_indices) = torch.arange(0, total_indices_length).split(section_sizes, dim=0)
    clean_latent_indices = torch.cat([start_indices, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence if it is shorter than the layout.
    if camera_embedding_full.shape[0] < total_indices_length:
        shortage = total_indices_length - camera_embedding_full.shape[0]
        pad = torch.zeros(shortage, camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype,
                          device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    combined_camera = camera_embedding_full[:total_indices_length, :].clone()

    # Rewrite the mask column: everything is a target (0.0) by default, then
    # the slots actually backed by history become condition frames (1.0).
    combined_camera[:, -1] = 0.0
    available_frames = min(T, 19) if T > 0 else 0
    if T > 0:
        start_pos = 19 - available_frames
        combined_camera[start_pos:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f"  - 历史帧数: {T}")
    print(f"  - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f"  - 模态类型: {modality_type}")

    # Fill the 19 context slots right-aligned with the latest history frames.
    clean_latents_combined = torch.zeros(C, 19, H, W,
                                         dtype=history_latents.dtype,
                                         device=history_latents.device)
    if T > 0:
        clean_latents_combined[:, 19 - available_frames:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # The very first history frame anchors the sequence start.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W,
                                   dtype=history_latents.dtype,
                                   device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,
        'current_length': T,
        'next_length': T + target_frames_to_generate,
    }
clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # Always anchor on the very first history frame when one exists; otherwise a zero latent.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag consumed by the MoE router
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }


def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai" or "nuscenes" (an "openx" branch also exists below)
    use_real_poses=True,
    scene_info_path=None,  # only used for the NuScenes dataset
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None
):
    """
    MoE FramePack sliding-window video generation with multi-modality support.

    Autoregressively generates `total_frames_to_generate` latent frames in chunks of
    `frames_per_generation`, conditioning each chunk on a FramePack-compressed view of
    the history latents plus per-frame camera embeddings, with optional camera/text CFG.
    The decoded video is written to `output_path`.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialization
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Add the legacy per-block camera encoder (kept for checkpoint compatibility);
    # zero-init encoder + identity projector so it starts as a no-op until weights load.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Add FramePack components
    add_framepack_components(pipe.dit)

    # 4. Add MoE components
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,  # Sekai: 12-dim pose + 1-dim mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-dim pose + 1-dim mask
        "openx_input_dim": 13  # OpenX: 12-dim pose + 1-dim mask (same layout as sekai)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load the trained weights
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)  # strict=False tolerates the newly added MoE components
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition latents
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Spatial center-crop to the training resolution (in latent space)
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the prompt (with optional text CFG negative prompt)
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 8. Load scene info (NuScenes only)
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera embedding sequence for the whole window
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            0,
            max_history_frames,
            0,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            0,
            max_history_frames,
            0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            0,
            max_history_frames,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. Unconditional (all-zero) camera embedding for camera CFG
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window generation loop
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation - MoE version
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Prepare model inputs (add batch dim)
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # MoE modality inputs keyed by modality name
        modality_inputs = {modality_type: camera_embedding}

        # Unconditional camera embedding batch for CFG
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors (kept on CPU)
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Initialize the latents to be generated from pure noise
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop - with optional camera/text CFG
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # CFG inference
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera)
                    noise_pred_cond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zeroed camera)
                    noise_pred_uncond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # MoE unconditional modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack text CFG on top of the camera-CFG result
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loess = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the camera-CFG-combined prediction
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only
                    noise_pred_cond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loess= pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Standard inference (no CFG)
                    noise_pred, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the freshly denoised chunk to the history
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep the first frame plus the most recent ones
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12.
解码和保存 + print("\n🔧 解码生成的视频...") + + all_generated = torch.cat(all_generated_frames, dim=1) + final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0) + + print(f"最终视频shape: {final_video.shape}") + + decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + + print(f"Saving video to {output_path}") + + video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy() + video_np = (video_np * 0.5 + 0.5).clip(0, 1) + video_np = (video_np * 255).astype(np.uint8) + + with imageio.get_writer(output_path, fps=20) as writer: + for frame in video_np: + writer.append_data(frame) + + print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}") + print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧") + print(f"使用模态: {modality_type}") + + +def main(): + parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态") + + # 基础参数 + parser.add_argument("--condition_pth", type=str, + default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth") + #default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth") + #default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth") + #default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth") + parser.add_argument("--start_frame", type=int, default=0) + parser.add_argument("--initial_condition_frames", type=int, default=16) + parser.add_argument("--frames_per_generation", type=int, default=8) + parser.add_argument("--total_frames_to_generate", type=int, default=24) + parser.add_argument("--max_history_frames", type=int, default=100) + parser.add_argument("--use_real_poses", default=True) + parser.add_argument("--dit_path", type=str, + 
default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step25000_first.ckpt") + parser.add_argument("--output_path", type=str, + default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4') + parser.add_argument("--prompt", type=str, + default="A drone flying scene in a game world ") + parser.add_argument("--device", type=str, default="cuda") + + # 模态类型参数 + parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai", + help="模态类型:sekai 或 nuscenes 或 openx") + parser.add_argument("--scene_info_path", type=str, default=None, + help="NuScenes场景信息文件路径(仅用于nuscenes模态)") + + # CFG参数 + parser.add_argument("--use_camera_cfg", default=False, + help="使用Camera CFG") + parser.add_argument("--camera_guidance_scale", type=float, default=2.0, + help="Camera guidance scale for CFG") + parser.add_argument("--text_guidance_scale", type=float, default=1.0, + help="Text guidance scale for CFG") + + # MoE参数 + parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量") + parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家") + parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度") + + args = parser.parse_args() + + print(f"🔧 MoE FramePack CFG生成设置:") + print(f"模态类型: {args.modality_type}") + print(f"Camera CFG: {args.use_camera_cfg}") + if args.use_camera_cfg: + print(f"Camera guidance scale: {args.camera_guidance_scale}") + print(f"Text guidance scale: {args.text_guidance_scale}") + print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}") + print(f"DiT{args.dit_path}") + + # 验证NuScenes参数 + if args.modality_type == "nuscenes" and not args.scene_info_path: + print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据") + + inference_moe_framepack_sliding_window( + condition_pth_path=args.condition_pth, + dit_path=args.dit_path, + output_path=args.output_path, + start_frame=args.start_frame, + 
initial_condition_frames=args.initial_condition_frames, + frames_per_generation=args.frames_per_generation, + total_frames_to_generate=args.total_frames_to_generate, + max_history_frames=args.max_history_frames, + device=args.device, + prompt=args.prompt, + modality_type=args.modality_type, + use_real_poses=args.use_real_poses, + scene_info_path=args.scene_info_path, + # CFG参数 + use_camera_cfg=args.use_camera_cfg, + camera_guidance_scale=args.camera_guidance_scale, + text_guidance_scale=args.text_guidance_scale, + # MoE参数 + moe_num_experts=args.moe_num_experts, + moe_top_k=args.moe_top_k, + moe_hidden_dim=args.moe_hidden_dim + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/infer_moe_spatialvid.py b/scripts/infer_moe_spatialvid.py new file mode 100644 index 0000000000000000000000000000000000000000..9fa8faf8c9e0804459cf03e2780e09f5cd659571 --- /dev/null +++ b/scripts/infer_moe_spatialvid.py @@ -0,0 +1,1008 @@ +import os +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +import imageio +import json +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import argparse +from torchvision.transforms import v2 +from einops import rearrange +import copy +from scipy.spatial.transform import Rotation as R + +def compute_relative_pose_matrix(pose1, pose2): + """ + 计算相邻两帧的相对位姿,返回3×4的相机矩阵 [R_rel | t_rel] + + 参数: + pose1: 第i帧的相机位姿,形状为(7,)的数组 [tx1, ty1, tz1, qx1, qy1, qz1, qw1] + pose2: 第i+1帧的相机位姿,形状为(7,)的数组 [tx2, ty2, tz2, qx2, qy2, qz2, qw2] + + 返回: + relative_matrix: 3×4的相对位姿矩阵,前3列是旋转矩阵R_rel,第4列是平移向量t_rel + """ + # 分离平移向量和四元数 + t1 = pose1[:3] # 第i帧平移 [tx1, ty1, tz1] + q1 = pose1[3:] # 第i帧四元数 [qx1, qy1, qz1, qw1] + t2 = pose2[:3] # 第i+1帧平移 + q2 = pose2[3:] # 第i+1帧四元数 + + # 1. 计算相对旋转矩阵 R_rel + rot1 = R.from_quat(q1) # 第i帧旋转 + rot2 = R.from_quat(q2) # 第i+1帧旋转 + rot_rel = rot2 * rot1.inv() # 相对旋转 = 后一帧旋转 × 前一帧旋转的逆 + R_rel = rot_rel.as_matrix() # 转换为3×3矩阵 + + # 2. 
Compute the relative translation vector t_rel
    R1_T = rot1.as_matrix().T  # transpose of the previous frame's rotation (== its inverse)
    t_rel = R1_T @ (t2 - t1)  # relative translation = R1^T × (t2 - t1)

    # 3. Assemble the 3×4 matrix [R_rel | t_rel]
    relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])

    return relative_matrix

def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file.

    Returns a `[C, num_frames, H, W]` slice of the stored latents starting at
    `start_frame`, plus the full decoded payload dict (which may also carry
    camera data under 'cam_emb' — presumably; verify against the encoder script).

    Raises:
        ValueError: if the requested frame range exceeds the stored sequence.
    """
    print(f"Loading encoded video from {pth_path}")

    # NOTE(review): weights_only=False unpickles arbitrary objects — only load trusted files.
    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    if start_frame + num_frames > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data


def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A from two 4×4 extrinsic matrices.

    Returns pose_b @ inv(pose_a), using torch or numpy depending on `use_torch`.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()

        pose_a_inv = torch.inverse(pose_a)
        relative_pose = torch.matmul(pose_b, pose_a_inv)
    else:
        if not isinstance(pose_a, np.ndarray):
            pose_a = np.array(pose_a, dtype=np.float32)
        if not isinstance(pose_b, np.ndarray):
            pose_b = np.array(pose_b, dtype=np.float32)

        pose_a_inv = np.linalg.inv(pose_a)
        relative_pose = np.matmul(pose_b, pose_a_inv)

    return relative_pose


def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' model class for the MoE variant in-place."""
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config
import model_loader_configs

    # Rewrite each loader config entry so 'wan_video_dit' resolves to WanModelMoe
    for i, config in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config

        if 'wan_video_dit' in model_names:
            new_model_names = []
            new_model_classes = []

            for name, cls in zip(model_names, model_classes):
                if name == 'wan_video_dit':
                    new_model_names.append(name)
                    new_model_classes.append(WanModelMoe)
                    print(f"✅ 替换了模型类: {name} -> WanModelMoe")
                else:
                    new_model_names.append(name)
                    new_model_classes.append(cls)

            model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource)


def add_framepack_components(dit_model):
    """Attach the FramePack `clean_x_embedder` (1x/2x/4x patchifiers) to the DiT model.

    No-op if the model already has a `clean_x_embedder`. The embedder is cast to
    the model's parameter dtype so mixed-precision forward passes work.
    """
    if not hasattr(dit_model, 'clean_x_embedder'):
        inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

        class CleanXEmbedder(nn.Module):
            # Projects 16-channel clean latents to inner_dim at three temporal/spatial
            # compression levels (1x, 2x, 4x) via strided Conv3d patchification.
            def __init__(self, inner_dim):
                super().__init__()
                self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
                self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
                self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

            def forward(self, x, scale="1x"):
                if scale == "1x":
                    x = x.to(self.proj.weight.dtype)
                    return self.proj(x)
                elif scale == "2x":
                    x = x.to(self.proj_2x.weight.dtype)
                    return self.proj_2x(x)
                elif scale == "4x":
                    x = x.to(self.proj_4x.weight.dtype)
                    return self.proj_4x(x)
                else:
                    raise ValueError(f"Unsupported scale: {scale}")

        dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
        model_dtype = next(dit_model.parameters()).dtype
        dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
        print("✅ 添加了FramePack的clean_x_embedder组件")


def add_moe_components(dit_model, moe_config):
    """🔧 Attach MoE components (modality processors + expert router) to every DiT block."""
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")

    # Dynamically attach MoE components to each transformer block
    dim =
dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)

    for i, block in enumerate(dit_model.blocks):
        from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

        # Sekai modality processor - projects to unified_dim
        block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)

        # # NuScenes modality processor - projects to unified_dim
        # block.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)

        # MoE network - takes unified_dim, emits dim
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,  # output dim matches the transformer block's dim
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")


def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Generate per-frame camera embeddings for the Sekai dataset (sliding-window version).

    Each row is a flattened 3×4 relative pose (12 values) plus a 1-value
    condition mask (1.0 = condition frame, 0.0 = target frame), so the result
    is `[max_needed_frames, 13]` in bfloat16. Falls back to a synthetic
    left-turn trajectory when real poses are unavailable.
    """
    time_compression_ratio = 4

    # Number of camera frames FramePack actually needs for one generation step
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Make sure the generated camera sequence is long enough for every consumer
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Position of this latent frame in the original (uncompressed) sequence
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose_matrix(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Out of range - use zero motion
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask sequence of matching length
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        # Frames from start_frame up to the history length are marked as condition
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用Sekai合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # Continuous left-turn motion pattern
            yaw_per_frame = 0.05  # left turn per frame (positive angle = left)
            forward_speed = 0.005  # forward distance per frame

            pose = np.eye(4, dtype=np.float32)

            # Rotation matrix (left turn about the Y axis)
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Translation (forward in the rotated local frame)
            pose[2, 3] = -forward_speed  # negative local Z (forward)

            # Slight centripetal drift to mimic a circular trajectory
            radius_drift = 0.002  # small drift toward the circle center
            pose[0, 3] = -radius_drift  # negative local X (to the left)

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask sequence of matching length
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Generate per-frame camera embeddings for the OpenX dataset (sliding-window version).

    Same 13-dim layout as the Sekai variant (12-value relative pose + 1-value
    condition mask); synthetic fallback mimics fine robot-arm motion.
    """
    time_compression_ratio = 4

    # Number of camera frames FramePack actually needs for one generation step
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        # Make sure the generated camera sequence is long enough
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX uses the same 4x stride as sekai but typically shorter sequences
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Out of range - use zero motion
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask sequence of matching length
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        # Frames from start_frame up to the history length are marked as condition
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用OpenX合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX robot-manipulation motion pattern - small motion amplitudes
            # to mimic fine robot-arm movements
            roll_per_frame = 0.02  # slight roll
            pitch_per_frame = 0.01  # slight pitch
            yaw_per_frame = 0.015  # slight yaw
            forward_speed = 0.003  # slower forward speed

            pose = np.eye(4, dtype=np.float32)

            # Composite rotation - mimic the arm's combined motion
            # about X (roll)
            cos_roll = np.cos(roll_per_frame)
            sin_roll = np.sin(roll_per_frame)
            # about Y (pitch)
            cos_pitch = np.cos(pitch_per_frame)
            sin_pitch = np.sin(pitch_per_frame)
            # about Z (yaw)
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            # Simplified composite rotation matrix (ZYX order)
            pose[0, 0] = cos_yaw * cos_pitch
            pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
            pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
            pose[1, 0] = sin_yaw * cos_pitch
            pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
            pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
            pose[2, 0] = -sin_pitch
            pose[2, 1] = cos_pitch * sin_roll
            pose[2, 2] = cos_pitch * cos_roll

            # Translation - fine manipulation-scale movement
            pose[0, 3] = forward_speed * 0.5  # small X motion
            pose[1, 3] = forward_speed * 0.3  # small Y motion
            pose[2, 3] = -forward_speed  # main motion along Z (depth)

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask sequence of matching length
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
        return
camera_embedding.to(torch.bfloat16)


def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Generate camera embeddings for NuScenes (sliding-window, matches train_moe.py).

    Each row is a 7-dim pose (3 translation + 4 quaternion, relative to the
    first keyframe) plus a 1-value condition mask, giving `[max_needed_frames, 8]`
    in bfloat16. Falls back to a synthetic straight-ahead trajectory.
    """
    time_compression_ratio = 4

    # Number of camera frames FramePack actually needs for one generation step
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if len(keyframe_poses) == 0:
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            max_needed_frames = max(framepack_needed_frames, 30)

            pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)

            mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
            condition_end = min(start_frame + current_history_length, max_needed_frames)
            mask[start_frame:condition_end] = 1.0

            camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
            print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)

        # Use the first pose as the reference frame
        reference_pose = keyframe_poses[0]

        max_needed_frames = max(framepack_needed_frames, 30)

        pose_vecs = []
        for i in range(max_needed_frames):
            if i < len(keyframe_poses):
                current_pose = keyframe_poses[i]

                # Relative displacement w.r.t. the reference pose
                translation = torch.tensor(
                    np.array(current_pose['translation']) - np.array(reference_pose['translation']),
                    dtype=torch.float32
                )

                # Relative rotation (simplified: absolute quaternion is kept as-is)
                rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)

                pose_vec = torch.cat([translation, rotation], dim=0)  # [7D]
            else:
                # Out of range - identity pose (zero translation, unit quaternion)
                pose_vec = torch.cat([
                    torch.zeros(3, dtype=torch.float32),
                    torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
                ], dim=0)  # [7D]

            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)  # [max_needed_frames, 7]

        # Build the condition mask
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用NuScenes合成pose数据")
        max_needed_frames = max(framepack_needed_frames, 30)

        # Build a synthetic motion sequence
        pose_vecs = []
        for i in range(max_needed_frames):
            # Simple forward motion
            translation = torch.tensor([0.0, 0.0, i * 0.1], dtype=torch.float32)  # forward along Z
            rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)  # no rotation

            pose_vec = torch.cat([translation, rotation], dim=0)  # [7D]
            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)

        # Build the condition mask
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes合成pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """FramePack sliding-window preparation - MoE version.

    Splits the fixed index layout [1 start | 16 4x | 2 2x | 1 1x | N targets],
    packs the most recent history latents into the 19 clean slots, slices and
    re-masks the camera sequence, and returns everything the DiT forward needs.
    """
    # history_latents: [C, T, H, W] current history latents
    C, T, H, W = history_latents.shape

    # Fixed index structure (this determines how many camera frames are needed)
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    indices = torch.arange(0, total_indices_length)
    split_sizes = [1, 16, 2, 1, target_frames_to_generate]
    clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \
        indices.split(split_sizes, dim=0)
    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence if it is too short
    if camera_embedding_full.shape[0] < total_indices_length:
        shortage = total_indices_length - camera_embedding_full.shape[0]
        padding = torch.zeros(shortage, camera_embedding_full.shape[1],
                              dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)

    # Slice the needed portion of the full camera sequence
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()

    # Re-derive the mask from the current history length
    combined_camera[:, -1] = 0.0  # first mark everything as target (0)

    # Condition mask: the first 19 slots depend on the actual history length
    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        combined_camera[start_pos:19, -1] = 1.0  # mark cameras of valid clean latents as condition

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Pack the latents: newest history frames fill the tail of the 19 clean slots
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)

    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag consumed by the MoE router
        'current_length':
T, + 'next_length': T + target_frames_to_generate + } + + +def inference_moe_framepack_sliding_window( + condition_pth_path, + dit_path, + output_path="moe/infer_results/output_moe_framepack_sliding.mp4", + start_frame=0, + initial_condition_frames=8, + frames_per_generation=4, + total_frames_to_generate=32, + max_history_frames=49, + device="cuda", + prompt="A video of a scene shot using a pedestrian's front camera while walking", + modality_type="sekai", # "sekai" 或 "nuscenes" + use_real_poses=True, + scene_info_path=None, # 对于NuScenes数据集 + # CFG参数 + use_camera_cfg=True, + camera_guidance_scale=2.0, + text_guidance_scale=1.0, + # MoE参数 + moe_num_experts=4, + moe_top_k=2, + moe_hidden_dim=None +): + """ + MoE FramePack滑动窗口视频生成 - 支持多模态 + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + print(f"🔧 MoE FramePack滑动窗口生成开始...") + print(f"模态类型: {modality_type}") + print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}") + print(f"Text guidance scale: {text_guidance_scale}") + print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}") + + # 1. 模型初始化 + replace_dit_model_in_manager() + + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ]) + pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda") + + # 2. 添加传统camera编码器(兼容性) + dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in pipe.dit.blocks: + block.cam_encoder = nn.Linear(13, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + # 3. 添加FramePack组件 + add_framepack_components(pipe.dit) + + # 4. 
添加MoE组件 + moe_config = { + "num_experts": moe_num_experts, + "top_k": moe_top_k, + "hidden_dim": moe_hidden_dim or dim * 2, + "sekai_input_dim": 13, # Sekai: 12维pose + 1维mask + "nuscenes_input_dim": 8, # NuScenes: 7维pose + 1维mask + "openx_input_dim": 13 # OpenX: 12维pose + 1维mask (类似sekai) + } + add_moe_components(pipe.dit, moe_config) + + # 5. 加载训练好的权重 + dit_state_dict = torch.load(dit_path, map_location="cpu") + pipe.dit.load_state_dict(dit_state_dict, strict=False) # 使用strict=False以兼容新增的MoE组件 + pipe = pipe.to(device) + model_dtype = next(pipe.dit.parameters()).dtype + + if hasattr(pipe.dit, 'clean_x_embedder'): + pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype) + + pipe.scheduler.set_timesteps(50) + + # 6. 加载初始条件 + print("Loading initial condition frames...") + initial_latents, encoded_data = load_encoded_video_from_pth( + condition_pth_path, + start_frame=start_frame, + num_frames=initial_condition_frames + ) + + # 空间裁剪 + target_height, target_width = 60, 104 + C, T, H, W = initial_latents.shape + + if H > target_height or W > target_width: + h_start = (H - target_height) // 2 + w_start = (W - target_width) // 2 + initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width] + H, W = target_height, target_width + + history_latents = initial_latents.to(device, dtype=model_dtype) + + print(f"初始history_latents shape: {history_latents.shape}") + + # 7. 编码prompt - 支持CFG + if text_guidance_scale > 1.0: + prompt_emb_pos = pipe.encode_prompt(prompt) + prompt_emb_neg = pipe.encode_prompt("") + print(f"使用Text CFG,guidance scale: {text_guidance_scale}") + else: + prompt_emb_pos = pipe.encode_prompt(prompt) + prompt_emb_neg = None + print("不使用Text CFG") + + # 8. 加载场景信息(对于NuScenes) + scene_info = None + if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path): + with open(scene_info_path, 'r') as f: + scene_info = json.load(f) + print(f"加载NuScenes场景信息: {scene_info_path}") + + # 9. 
预生成完整的camera embedding序列 + if modality_type == "sekai": + camera_embedding_full = generate_sekai_camera_embeddings_sliding( + encoded_data.get('cam_emb', None), + 0, + max_history_frames, + 0, + 0, + use_real_poses=use_real_poses + ).to(device, dtype=model_dtype) + elif modality_type == "nuscenes": + camera_embedding_full = generate_nuscenes_camera_embeddings_sliding( + scene_info, + 0, + max_history_frames, + 0 + ).to(device, dtype=model_dtype) + elif modality_type == "openx": + camera_embedding_full = generate_openx_camera_embeddings_sliding( + encoded_data, + 0, + max_history_frames, + 0, + use_real_poses=use_real_poses + ).to(device, dtype=model_dtype) + else: + raise ValueError(f"不支持的模态类型: {modality_type}") + + print(f"完整camera序列shape: {camera_embedding_full.shape}") + + # 10. 为Camera CFG创建无条件的camera embedding + if use_camera_cfg: + camera_embedding_uncond = torch.zeros_like(camera_embedding_full) + print(f"创建无条件camera embedding用于CFG") + + # 11. 滑动窗口生成循环 + total_generated = 0 + all_generated_frames = [] + + while total_generated < total_frames_to_generate: + current_generation = min(frames_per_generation, total_frames_to_generate - total_generated) + print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}") + print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}") + + # FramePack数据准备 - MoE版本 + framepack_data = prepare_framepack_sliding_window_with_camera_moe( + history_latents, + current_generation, + camera_embedding_full, + start_frame, + modality_type, + max_history_frames + ) + + # 准备输入 + clean_latents = framepack_data['clean_latents'].unsqueeze(0) + clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0) + clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0) + camera_embedding = framepack_data['camera_embedding'].unsqueeze(0) + + # 准备modality_inputs + modality_inputs = {modality_type: camera_embedding} + + # 为CFG准备无条件camera embedding + if use_camera_cfg: + camera_embedding_uncond_batch = 
camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0) + modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch} + + # 索引处理 + latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu() + clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu() + clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu() + clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu() + + # 初始化要生成的latents + new_latents = torch.randn( + 1, C, current_generation, H, W, + device=device, dtype=model_dtype + ) + + extra_input = pipe.prepare_extra_input(new_latents) + + print(f"Camera embedding shape: {camera_embedding.shape}") + print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}") + + # 去噪循环 - 支持CFG + timesteps = pipe.scheduler.timesteps + + for i, timestep in enumerate(timesteps): + if i % 10 == 0: + print(f" 去噪步骤 {i+1}/{len(timesteps)}") + + timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype) + + with torch.no_grad(): + # CFG推理 + if use_camera_cfg and camera_guidance_scale > 1.0: + # 条件预测(有camera) + noise_pred_cond, moe_loss = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, # MoE模态输入 + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + # 无条件预测(无camera) + noise_pred_uncond, moe_loss = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding_uncond_batch, + modality_inputs=modality_inputs_uncond, # MoE无条件模态输入 + latent_indices=latent_indices, + clean_latents=clean_latents, + 
clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos), + **extra_input + ) + + # Camera CFG + noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # 如果同时使用Text CFG + if text_guidance_scale > 1.0 and prompt_emb_neg: + noise_pred_text_uncond, moe_loss = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_neg, + **extra_input + ) + + # 应用Text CFG到已经应用Camera CFG的结果 + noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond) + + elif text_guidance_scale > 1.0 and prompt_emb_neg: + # 只使用Text CFG + noise_pred_cond, moe_loss = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + noise_pred_uncond, moe_loss = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + 
clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_neg, + **extra_input + ) + + noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond) + + else: + # 标准推理(无CFG) + noise_pred, moe_loss = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, # MoE模态输入 + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents) + + # 更新历史 + new_latents_squeezed = new_latents.squeeze(0) + history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1) + + # 维护滑动窗口 + if history_latents.shape[1] > max_history_frames: + first_frame = history_latents[:, 0:1, :, :] + recent_frames = history_latents[:, -(max_history_frames-1):, :, :] + history_latents = torch.cat([first_frame, recent_frames], dim=1) + print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧") + + print(f"更新后history_latents shape: {history_latents.shape}") + + all_generated_frames.append(new_latents_squeezed) + total_generated += current_generation + + print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧") + + # 12. 
解码和保存 + print("\n🔧 解码生成的视频...") + + all_generated = torch.cat(all_generated_frames, dim=1) + final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0) + + print(f"最终视频shape: {final_video.shape}") + + decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + + print(f"Saving video to {output_path}") + + video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy() + video_np = (video_np * 0.5 + 0.5).clip(0, 1) + video_np = (video_np * 255).astype(np.uint8) + + with imageio.get_writer(output_path, fps=20) as writer: + for frame in video_np: + writer.append_data(frame) + + print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}") + print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧") + print(f"使用模态: {modality_type}") + + +def main(): + parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态") + + # 基础参数 + parser.add_argument("--condition_pth", type=str, + #default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth") + #default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth") + default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth") + #default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth") + parser.add_argument("--start_frame", type=int, default=0) + parser.add_argument("--initial_condition_frames", type=int, default=16) + parser.add_argument("--frames_per_generation", type=int, default=8) + parser.add_argument("--total_frames_to_generate", type=int, default=8) + parser.add_argument("--max_history_frames", type=int, default=100) + parser.add_argument("--use_real_poses", action="store_true", default=False) + parser.add_argument("--dit_path", type=str, + 
default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_spatialvid/step250_moe.ckpt") + parser.add_argument("--output_path", type=str, + default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4') + parser.add_argument("--prompt", type=str, + default="A man enter the room") + parser.add_argument("--device", type=str, default="cuda") + + # 模态类型参数 + parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai", + help="模态类型:sekai 或 nuscenes 或 openx") + parser.add_argument("--scene_info_path", type=str, default=None, + help="NuScenes场景信息文件路径(仅用于nuscenes模态)") + + # CFG参数 + parser.add_argument("--use_camera_cfg", default=True, + help="使用Camera CFG") + parser.add_argument("--camera_guidance_scale", type=float, default=2.0, + help="Camera guidance scale for CFG") + parser.add_argument("--text_guidance_scale", type=float, default=1.0, + help="Text guidance scale for CFG") + + # MoE参数 + parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量") + parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家") + parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度") + + args = parser.parse_args() + + print(f"🔧 MoE FramePack CFG生成设置:") + print(f"模态类型: {args.modality_type}") + print(f"Camera CFG: {args.use_camera_cfg}") + if args.use_camera_cfg: + print(f"Camera guidance scale: {args.camera_guidance_scale}") + print(f"Text guidance scale: {args.text_guidance_scale}") + print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}") + + # 验证NuScenes参数 + if args.modality_type == "nuscenes" and not args.scene_info_path: + print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据") + + inference_moe_framepack_sliding_window( + condition_pth_path=args.condition_pth, + dit_path=args.dit_path, + output_path=args.output_path, + start_frame=args.start_frame, + initial_condition_frames=args.initial_condition_frames, + 
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice out a window.

    Args:
        pth_path: path to a torch-saved dict holding at least a 'latents'
            tensor of shape [C, T, H, W].
        start_frame: first (compressed) frame of the slice.
        num_frames: number of frames to extract.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice
        and the full dictionary loaded from disk.

    Raises:
        ValueError: when the requested window extends past the stored frames.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute the pose of camera B relative to camera A, i.e. B @ A^-1.

    Both inputs must be 4x4 extrinsic matrices. The result uses the same
    backend as requested: a torch tensor when use_torch=True, otherwise a
    numpy array (non-matching inputs are converted to float32 first).
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
def add_moe_components(dit_model, moe_config):
    """Attach MoE components to every transformer block of the DiT model.

    Stores moe_config on the model, then gives each block a Sekai modality
    processor (13-dim camera input -> unified_dim) and a MultiModalMoE head
    whose output width matches the transformer block width. The hasattr
    guard makes the call idempotent.

    NOTE(review): guard scope inferred from the sibling
    add_framepack_components (whole body inside the `if`) — confirm the
    per-block loop is not meant to run when moe_config already exists.
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")

        # Hoisted out of the per-block loop: the imported classes are loop-invariant.
        from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

        # Transformer block width; the MoE output must match it.
        dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
        unified_dim = moe_config.get("unified_dim", 25)

        for i, block in enumerate(dit_model.blocks):
            # Sekai modality processor: 13-dim pose+mask -> unified_dim
            block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)

            # MoE network: unified_dim in, block width out
            block.moe = MultiModalMoE(
                unified_dim=unified_dim,
                output_dim=dim,
                num_experts=moe_config.get("num_experts", 4),
                top_k=moe_config.get("top_k", 2)
            )

            print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build the sliding-window camera embedding sequence for Sekai data.

    Each row is a flattened 3x4 relative pose (12 values) plus a trailing
    condition-mask scalar set to 1.0 for rows in
    [start_frame, start_frame + current_history_length). Returns a
    [N, 13] bfloat16 tensor. Falls back to a synthetic left-turn motion
    when real extrinsics are unavailable.
    """
    time_compression_ratio = 4

    # Frames FramePack consumes: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Sequence long enough for history + new frames, FramePack layout, and a floor of 30.
    seq_len = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    def _finalize(pose_rows):
        # Flatten [N, 3, 4] poses to [N, 12] and append the condition-mask column.
        flat = torch.stack(pose_rows, dim=0).reshape(seq_len, -1)
        mask = torch.zeros(seq_len, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + current_history_length, seq_len)] = 1.0
        return torch.cat([flat, mask], dim=1)

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {seq_len}")

        rows = []
        for i in range(seq_len):
            # Position of this compressed frame in the raw video sequence.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                rows.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                rows.append(torch.zeros(3, 4))

        camera_embedding = _finalize(rows)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用Sekai合成camera数据")
    print(f"🔧 生成Sekai合成camera帧数: {seq_len}")

    # Constant per-frame motion: yaw left, creep forward, drift toward the circle center.
    yaw_per_frame = 0.05      # left turn per frame (positive = left)
    forward_speed = 0.005     # forward travel per frame
    radius_drift = 0.002      # slight pull toward the circle center

    pose = np.eye(4, dtype=np.float32)
    cos_yaw, sin_yaw = np.cos(yaw_per_frame), np.sin(yaw_per_frame)
    pose[0, 0], pose[0, 2] = cos_yaw, sin_yaw     # rotation about the Y axis
    pose[2, 0], pose[2, 2] = -sin_yaw, cos_yaw
    pose[2, 3] = -forward_speed                   # forward along local -Z
    pose[0, 3] = -radius_drift                    # leftward along local -X

    rows = [torch.as_tensor(pose[:3, :]) for _ in range(seq_len)]

    camera_embedding = _finalize(rows)
    print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Build the sliding-window camera embedding sequence for OpenX data.

    Same layout as the Sekai variant: each row is a flattened 3x4 relative
    pose (12 values) plus a trailing condition-mask scalar; returns a
    [N, 13] bfloat16 tensor. Falls back to a small synthetic robot-arm
    motion pattern when real extrinsics are unavailable.
    """
    time_compression_ratio = 4

    # Frames FramePack consumes: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Sequence long enough for history + new frames, FramePack layout, and a floor of 30.
    seq_len = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    def _finalize(pose_rows):
        # Flatten [N, 3, 4] poses to [N, 12] and append the condition-mask column.
        flat = torch.stack(pose_rows, dim=0).reshape(seq_len, -1)
        mask = torch.zeros(seq_len, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + current_history_length, seq_len)] = 1.0
        return torch.cat([flat, mask], dim=1)

    has_real = (use_real_poses and encoded_data is not None
                and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb'])

    if has_real:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {seq_len}")

        rows = []
        for i in range(seq_len):
            # OpenX uses the same 4x compression stride as Sekai, on shorter episodes.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                rows.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                rows.append(torch.zeros(3, 4))

        camera_embedding = _finalize(rows)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用OpenX合成camera数据")
    print(f"🔧 生成OpenX合成camera帧数: {seq_len}")

    # Synthetic robot-manipulation motion: small, constant per-frame deltas.
    roll_per_frame = 0.02
    pitch_per_frame = 0.01
    yaw_per_frame = 0.015
    forward_speed = 0.003

    cos_roll, sin_roll = np.cos(roll_per_frame), np.sin(roll_per_frame)
    cos_pitch, sin_pitch = np.cos(pitch_per_frame), np.sin(pitch_per_frame)
    cos_yaw, sin_yaw = np.cos(yaw_per_frame), np.sin(yaw_per_frame)

    pose = np.eye(4, dtype=np.float32)
    # Composite ZYX rotation approximating a gentle arm sweep.
    pose[0, 0] = cos_yaw * cos_pitch
    pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
    pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
    pose[1, 0] = sin_yaw * cos_pitch
    pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
    pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
    pose[2, 0] = -sin_pitch
    pose[2, 1] = cos_pitch * sin_roll
    pose[2, 2] = cos_pitch * cos_roll
    # Fine translation: mostly along depth, with slight X/Y adjustment.
    pose[0, 3] = forward_speed * 0.5
    pose[1, 3] = forward_speed * 0.3
    pose[2, 3] = -forward_speed

    rows = [torch.as_tensor(pose[:3, :]) for _ in range(seq_len)]

    camera_embedding = _finalize(rows)
    print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Build the sliding-window camera embedding sequence for NuScenes.

    Produces a [N, 8] bfloat16 tensor: 3 translation dims (relative to the
    first keyframe), a 4-dim quaternion, and a trailing condition-mask
    scalar set to 1.0 for rows in
    [start_frame, start_frame + current_history_length). Falls back to a
    synthetic forward-motion sequence when scene_info has no keyframe poses.
    """
    time_compression_ratio = 4  # kept for parity with the other modalities (unused here)

    # Frames FramePack consumes: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    total = max(framepack_needed_frames, 30)

    def _attach_mask(pose_sequence):
        # Append the condition-mask column and cast to the model dtype.
        mask = torch.zeros(total, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + current_history_length, total)] = 1.0
        return torch.cat([pose_sequence, mask], dim=1).to(torch.bfloat16)

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if len(keyframe_poses) == 0:
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            camera_embedding = _attach_mask(torch.zeros(total, 7, dtype=torch.float32))
            print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
            return camera_embedding

        # All translations are expressed relative to the first keyframe.
        ref_translation = np.array(keyframe_poses[0]['translation'])

        # Identity quaternion row used to pad past the recorded keyframes.
        pad_vec = torch.cat([
            torch.zeros(3, dtype=torch.float32),
            torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
        ], dim=0)

        pose_vecs = []
        for i in range(total):
            if i < len(keyframe_poses):
                kp = keyframe_poses[i]
                translation = torch.tensor(
                    np.array(kp['translation']) - ref_translation,
                    dtype=torch.float32
                )
                # Simplified: absolute quaternion, not a true relative rotation.
                rotation = torch.tensor(kp['rotation'], dtype=torch.float32)
                pose_vecs.append(torch.cat([translation, rotation], dim=0))
            else:
                pose_vecs.append(pad_vec)

        camera_embedding = _attach_mask(torch.stack(pose_vecs, dim=0))
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding

    print("🔧 使用NuScenes合成pose数据")

    # Synthetic motion: steady forward translation along Z, identity rotation.
    pose_vecs = [
        torch.cat([
            torch.tensor([0.0, 0.0, i * 0.1], dtype=torch.float32),
            torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
        ], dim=0)
        for i in range(total)
    ]

    camera_embedding = _attach_mask(torch.stack(pose_vecs, dim=0))
    print(f"🔧 NuScenes合成pose embedding shape: {camera_embedding.shape}")
    return camera_embedding
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """FramePack sliding-window preparation, MoE variant.

    Lays out a fixed index structure (1 start + 16 four-x + 2 two-x +
    1 one-x + target slots), right-aligns the most recent history latents
    into a 19-frame window at three compression scales, and relabels the
    camera condition mask so only slots backed by real history count as
    condition frames.

    Args:
        history_latents: [C, T, H, W] accumulated history (T may be 0).
        target_frames_to_generate: number of new latent frames to denoise.
        camera_embedding_full: [N, D] full camera sequence; last column is
            the condition mask and is rewritten here; zero-padded if short.
        start_frame, max_history_frames: kept for interface compatibility
            (unused in the body).

    Returns:
        Dict with index tensors, clean latents at 1x/2x/4x scale, the
        masked camera embedding, modality_type, and length bookkeeping.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout; its total length also dictates the camera slice size.
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    parts = torch.arange(0, total_indices_length).split(
        [1, 16, 2, 1, target_frames_to_generate], dim=0)
    start_idx, clean_latent_4x_indices, clean_latent_2x_indices, one_x_idx, latent_indices = parts
    clean_latent_indices = torch.cat([start_idx, one_x_idx], dim=0)

    # Zero-pad the camera sequence on the right if it is too short.
    deficit = total_indices_length - camera_embedding_full.shape[0]
    if deficit > 0:
        filler = torch.zeros(deficit, camera_embedding_full.shape[1],
                             dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, filler], dim=0)

    combined_camera = camera_embedding_full[:total_indices_length, :].clone()

    # Rebuild the condition mask: everything is target (0) unless backed by history.
    combined_camera[:, -1] = 0.0
    cond_frames = min(T, 19) if T > 0 else 0
    if cond_frames:
        combined_camera[19 - cond_frames:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {cond_frames}")
    print(f" - 模态类型: {modality_type}")

    # Right-align the most recent history into a fixed 19-frame window.
    window = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if cond_frames:
        window[:, 19 - cond_frames:, :, :] = history_latents[:, -cond_frames:, :, :]

    clean_latents_4x = window[:, 0:16, :, :]
    clean_latents_2x = window[:, 16:18, :, :]
    clean_latents_1x = window[:, 18:19, :, :]

    # The very first history frame anchors the sequence; zeros when history is empty.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    return {
        'latent_indices': latent_indices,
        'clean_latents': torch.cat([start_latent, clean_latents_1x], dim=1),
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
    # 1. Model initialization: swap the DiT class before loading so the
    # checkpoint is materialized into the FramePack/MoE-capable model.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Add the legacy per-block camera encoder (kept for checkpoint
    # compatibility). Encoder is zero-initialized; projector starts as identity.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Add FramePack components (clean_x_embedder).
    add_framepack_components(pipe.dit)

    # 4. Add MoE components; per-modality input dims are pose-dim + 1 mask bit.
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,    # Sekai: 12-dim pose + 1-dim mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-dim pose + 1-dim mask
        "openx_input_dim": 13     # OpenX: 12-dim pose + 1-dim mask (like sekai)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load the trained weights. strict=False tolerates the newly added
    # FramePack/MoE parameters missing from older checkpoints.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition latents from the pre-encoded .pth file.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Spatial center-crop to the model's latent resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the prompt; a second, empty-prompt encoding is only needed for
    # text CFG.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 8. Optional NuScenes scene metadata (only read when the file exists).
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera embedding sequence for the chosen modality.
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            0,
            max_history_frames,
            0,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            0,
            max_history_frames,
            0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            0,
            max_history_frames,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. Unconditional (all-zero) camera embedding for camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window generation loop: generate frames_per_generation latent
    # frames per iteration until total_frames_to_generate is reached.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation - MoE version.
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # MoE routing input, keyed by modality.
        modality_inputs = {modality_type: camera_embedding}

        # Unconditional camera embedding for CFG, trimmed to the packed length.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors are consumed on CPU by the DiT.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Initialize the latents to be generated from pure noise.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop - supports camera CFG and/or text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # CFG inference
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera).
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zeroed camera).
                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # MoE unconditional modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack text CFG on top of the camera-CFG result.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loss = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the already camera-guided prediction.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only.
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Standard inference (no CFG).
                    noise_pred, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Update the history with the newly generated latents.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep the first (anchor) frame plus the
        # most recent max_history_frames-1 frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode and save.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    # NOTE(review): initial_latents keeps its original dtype here while the
    # generated latents are in model_dtype; torch.cat requires matching dtypes,
    # so this assumes the .pth latents were stored in the model dtype — confirm.
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] -> [0, 255] uint8, frame-major for the video writer.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
def main():
    """CLI entry point for MoE FramePack sliding-window video generation."""
    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    def _str2bool(value):
        # Parse common boolean spellings. Without this, argparse stores the raw
        # string, and any non-empty value — including "False" — is truthy.
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() in ("1", "true", "t", "yes", "y")

    # Basic parameters
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    # Alternative dataset examples:
    # default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth"
    # default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth"
    # default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth"
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=40)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_test/step1000_moe.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world")
    parser.add_argument("--device", type=str, default="cuda")

    # Modality-type parameters
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
                        help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # CFG parameters
    # FIX: was declared with only `default=True` (no type/action), so
    # `--use_camera_cfg False` parsed as the truthy string "False".
    parser.add_argument("--use_camera_cfg", type=_str2bool, default=True,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # MoE parameters
    parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")

    args = parser.parse_args()

    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")

    # Warn when the NuScenes modality is selected without scene metadata.
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim
    )


if __name__ == "__main__":
    main()
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Return the quaternion rotating the reference frame into the current one.

    Both inputs are (w, x, y, z) quaternions; the result is
    q_relative = q_ref^-1 * q_current, computed via the Hamilton product with
    the conjugate of the reference quaternion.
    """
    cur = torch.tensor(current_rotation, dtype=torch.float32)
    ref = torch.tensor(reference_rotation, dtype=torch.float32)

    # Conjugate of the reference quaternion (its inverse for unit quaternions).
    ref_conj = torch.tensor([ref[0], -ref[1], -ref[2], -ref[3]])

    aw, ax, ay, az = ref_conj
    bw, bx, by, bz = cur

    # Hamilton product: q_relative = ref_conj * cur.
    rel_w = aw * bw - ax * bx - ay * by - az * bz
    rel_x = aw * bx + ax * bw + ay * bz - az * by
    rel_y = aw * by - ax * bz + ay * bw + az * bx
    rel_z = aw * bz + ax * by - ay * bx + az * bw

    return torch.tensor([rel_w, rel_x, rel_y, rel_z])
def generate_direction_poses(direction="forward", target_frames=10, condition_frames=20):
    """Generate pose-class embeddings for a commanded camera direction.

    Builds an 8-D pose vector per frame (3 translation + 4 quaternion +
    1 frame-type flag) for condition_frames steady condition frames followed by
    target_frames frames moving in the requested direction, then converts the
    sequence into class embeddings via PoseClassifier and
    create_enhanced_class_embedding_for_inference.

    Args:
        direction: one of 'forward', 'backward', 'left_turn', 'right_turn'.
            FIX: the previous default "left" was not a handled value, so
            calling with defaults raised ValueError.
        target_frames: number of target frames.
        condition_frames: number of condition frames.

    Returns:
        [condition_frames + target_frames, 512] embedding tensor.

    Raises:
        ValueError: if direction is not one of the four handled values.
    """
    classifier = PoseClassifier()

    total_frames = condition_frames + target_frames
    print(f"conditon{condition_frames}")
    print(f"target{target_frames}")
    poses = []

    # Condition-frame poses: slow, steady forward motion with no rotation.
    for i in range(condition_frames):
        t = i / max(1, condition_frames - 1)  # 0 to 1

        translation = [-t * 0.5, 0.0, 0.0]  # slow forward drift
        rotation = [1.0, 0.0, 0.0, 0.0]  # identity quaternion
        frame_type = 0.0  # condition

        pose_vec = translation + rotation + [frame_type]  # 8D vector
        poses.append(pose_vec)

    # Target-frame poses: motion determined by the requested direction.
    for i in range(target_frames):
        t = i / max(1, target_frames - 1)  # 0 to 1

        if direction == "forward":
            # Forward: move along -x, no rotation.
            translation = [-(condition_frames * 0.5 + t * 2.0), 0.0, 0.0]
            rotation = [1.0, 0.0, 0.0, 0.0]  # unit quaternion

        elif direction == "backward":
            # Backward: move along +x, no rotation.
            translation = [-(condition_frames * 0.5) + t * 2.0, 0.0, 0.0]
            rotation = [1.0, 0.0, 0.0, 0.0]

        elif direction == "left_turn":
            # Left turn: forward motion plus positive yaw about z.
            translation = [-(condition_frames * 0.5 + t * 1.5), t * 0.5, 0.0]
            yaw = t * 0.3  # left-turn angle (radians)
            rotation = [
                np.cos(yaw/2),  # w
                0.0,            # x
                0.0,            # y
                np.sin(yaw/2)   # z (positive = left)
            ]

        elif direction == "right_turn":
            # Right turn: forward motion plus negative yaw about z.
            translation = [-(condition_frames * 0.5 + t * 1.5), -t * 0.5, 0.0]
            yaw = -t * 0.3  # right-turn angle (radians)
            rotation = [
                np.cos(abs(yaw)/2),  # w (cos is even, so equals cos(yaw/2))
                0.0,                 # x
                0.0,                 # y
                np.sin(yaw/2)        # z (negative = right)
            ]
        else:
            raise ValueError(f"Unknown direction: {direction}")

        frame_type = 1.0  # target
        pose_vec = translation + rotation + [frame_type]  # 8D vector
        poses.append(pose_vec)

    pose_sequence = torch.tensor(poses, dtype=torch.float32)

    # Classify only the target portion (first 7 dims, dropping the frame-type).
    target_pose_sequence = pose_sequence[condition_frames:, :7]

    # Condition frames are all labelled as class 0 (forward).
    condition_classes = torch.full((condition_frames,), 0, dtype=torch.long)
    target_classes = classifier.classify_pose_sequence(target_pose_sequence)
    full_classes = torch.cat([condition_classes, target_classes], dim=0)

    # Build the enhanced class embedding for the whole sequence.
    class_embeddings = create_enhanced_class_embedding_for_inference(
        full_classes, pose_sequence, embed_dim=512
    )

    print(f"Generated {direction} poses:")
    print(f"  Total frames: {total_frames} (condition: {condition_frames}, target: {target_frames})")
    analysis = classifier.analyze_pose_sequence(target_pose_sequence)
    print(f"  Target class distribution: {analysis['class_distribution']}")
    print(f"  Target motion segments: {len(analysis['motion_segments'])}")

    return class_embeddings
def create_enhanced_class_embedding_for_inference(class_labels: torch.Tensor, pose_sequence: torch.Tensor, embed_dim: int = 512) -> torch.Tensor:
    """Create enhanced class embeddings at inference time.

    Combines a direction-vector encoding of the class labels with frame-type
    flags and the raw translation/quaternion components, then expands the
    13-D feature to embed_dim via a fixed projection whose top-left 13x13
    block is the identity (so the first 13 output dims reproduce the features).

    Args:
        class_labels: [num_frames] long tensor with values in {0, 1, 2, 3}.
        pose_sequence: [num_frames, >=8] tensor; columns 0:3 are translation,
            3:7 quaternion, and the last column the frame-type flag (0/1).
        embed_dim: output embedding width.

    Returns:
        [num_frames, embed_dim] float tensor.
    """
    num_classes = 4
    num_frames = len(class_labels)

    # Base direction embedding, one row per class.
    direction_vectors = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],   # forward
        [-1.0, 0.0, 0.0, 0.0],  # backward
        [0.0, 1.0, 0.0, 0.0],   # left_turn
        [0.0, -1.0, 0.0, 0.0],  # right_turn
    ], dtype=torch.float32)

    # One-hot encode the labels.
    one_hot = torch.zeros(num_frames, num_classes)
    one_hot.scatter_(1, class_labels.unsqueeze(1), 1)

    # Direction-vector-based base embedding.
    base_embeddings = one_hot @ direction_vectors  # [num_frames, 4]

    # Frame-type information (condition vs. target).
    frame_types = pose_sequence[:, -1]  # last column is the frame-type flag
    frame_type_embeddings = torch.zeros(num_frames, 2)
    frame_type_embeddings[:, 0] = (frame_types == 0).float()  # condition
    frame_type_embeddings[:, 1] = (frame_types == 1).float()  # target

    # Geometric pose information.
    translations = pose_sequence[:, :3]  # [num_frames, 3]
    rotations = pose_sequence[:, 3:7]    # [num_frames, 4]

    # Concatenate all features.
    combined_features = torch.cat([
        base_embeddings,        # [num_frames, 4]
        frame_type_embeddings,  # [num_frames, 2]
        translations,           # [num_frames, 3]
        rotations,              # [num_frames, 4]
    ], dim=1)  # [num_frames, 13]

    # Expand to the target dimension.
    if embed_dim > 13:
        # FIX: previously used an unseeded torch.randn, so the projection (and
        # hence the embeddings) changed on every call/run, and could never
        # match whatever projection was used at training time. A seeded
        # generator makes the expansion deterministic and repeatable.
        gen = torch.Generator().manual_seed(0)
        expand_matrix = torch.randn(13, embed_dim, generator=gen) * 0.1
        expand_matrix[:13, :13] = torch.eye(13)
        embeddings = combined_features @ expand_matrix
    else:
        embeddings = combined_features[:, :embed_dim]

    return embeddings
def inference_nuscenes_video(
    condition_video_path,
    dit_path,
    text_encoder_path,
    vae_path,
    output_path="nus/infer_results/output_nuscenes.mp4",
    condition_frames=20,
    target_frames=3,
    height=900,
    width=1600,
    device="cuda",
    prompt="A car driving scene captured by front camera",
    poses_path=None,
    direction="forward"
):
    """Direction-class-controlled NuScenes video inference.

    Loads the pipeline, conditions on a video clip and a synthetic pose-class
    embedding for the requested direction, denoises target_frames new latent
    frames, and writes the decoded video to output_path.

    Args:
        condition_video_path: path to the conditioning video.
        dit_path: trained DiT checkpoint (.ckpt).
        text_encoder_path, vae_path: accepted for CLI compatibility but unused
            here — the model paths are hard-coded in load_models below.
        output_path: output .mp4 path.
        condition_frames: number of raw condition frames (divided by 4 after
            VAE temporal compression).
        target_frames: number of latent frames to generate.
        height, width: accepted but not applied — load_video_frames resizes to
            480x832 internally.
        device: inference device.
        prompt: text prompt.
        poses_path: unused here; direction-based poses are always generated.
        direction: 'forward', 'backward', 'left_turn', or 'right_turn'.
    """
    # FIX: os.makedirs("") raises when output_path has no directory component;
    # fall back to the current directory.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    print(f"Setting up models for {direction} movement...")

    # 1. Load models.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add camera components to the DiT: zero-initialized encoder for the
    # 512-dim class embedding, identity-initialized projector.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(512, dim)  # keeps the 512-dim embedding
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Load trained DiT weights.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    pipe.scheduler.set_timesteps(50)

    print("Loading condition video...")

    # Load the condition video frames.
    condition_video = load_video_frames(
        condition_video_path,
        num_frames=condition_frames,
        height=height,
        width=width
    )

    if condition_video is None:
        raise ValueError(f"Failed to load condition video from {condition_video_path}")

    condition_video = condition_video.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

    print("Processing poses...")

    # Number of condition frames after the VAE's 4x temporal compression;
    # reused below when slicing noise predictions.
    compressed_condition_frames = int(condition_frames / 4)

    # Generate the pose embedding covering both condition and target frames.
    print(f"Generating {direction} movement poses...")
    camera_embedding = generate_direction_poses(
        direction=direction,
        target_frames=target_frames,
        condition_frames=compressed_condition_frames
    )

    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16)

    print(f"Camera embedding shape: {camera_embedding.shape}")
    print(f"Generated poses for direction: {direction}")

    print("Encoding inputs...")

    # Encode the text prompt.
    prompt_emb = pipe.encode_prompt(prompt)

    # Encode the condition video to latents.
    condition_latents = pipe.encode_video(condition_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))[0]

    print("Generating video...")

    batch_size = 1
    channels = condition_latents.shape[0]
    latent_height = condition_latents.shape[2]
    latent_width = condition_latents.shape[3]
    target_height, target_width = 60, 104  # model's latent resolution

    if latent_height > target_height or latent_width > target_width:
        # Center-crop the latents to the model resolution.
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2
        condition_latents = condition_latents[:, :,
                                              h_start:h_start+target_height,
                                              w_start:w_start+target_width]
        latent_height = target_height
        latent_width = target_width
    condition_latents = condition_latents.to(device, dtype=pipe.torch_dtype)
    condition_latents = condition_latents.unsqueeze(0)
    # Small noise injection on the condition latents for output diversity.
    condition_latents = condition_latents + 0.05 * torch.randn_like(condition_latents)

    # Initialize the target latents with pure noise.
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=pipe.torch_dtype
    )
    print(target_latents.shape)
    print(camera_embedding.shape)
    # Concatenate condition and target latents along the time axis.
    combined_latents = torch.cat([condition_latents, target_latents], dim=2)
    print(combined_latents.shape)

    extra_input = pipe.prepare_extra_input(combined_latents)

    # Denoising loop: only the target portion is updated each step.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

        with torch.no_grad():
            noise_pred = pipe.dit(
                combined_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                **prompt_emb,
                **extra_input
            )

        # Update only the target part of the prediction.
        target_noise_pred = noise_pred[:, :, compressed_condition_frames:, :, :]
        target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents)

        combined_latents[:, :, compressed_condition_frames:, :, :] = target_latents

    print("Decoding video...")

    # Decode the final (condition + generated) latent video.
    final_video = torch.cat([condition_latents, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] -> [0, 255] uint8, frame-major for the writer.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"Video generation completed! Saved to {output_path}")
def main():
    """CLI entry point for direction-controlled NuScenes video inference."""
    parser = argparse.ArgumentParser(description="NuScenes Video Generation Inference with Direction Control")
    parser.add_argument("--condition_video", type=str, default="/home/zhuyixuan05/ReCamMaster/nus/videos/4032/right.mp4",
                        help="Path to condition video")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement")
    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/nus_dynamic/step15000_dynamic.ckpt",
                        help="Path to trained DiT checkpoint")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
                        help="Path to text encoder")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
                        help="Path to VAE")
    # FIX: the default used to be a fixed path, which made the
    # `args.output_path is None` direction-based-naming branch below dead code.
    parser.add_argument("--output_path", type=str, default=None,
                        help="Output video path")
    parser.add_argument("--poses_path", type=str, default=None,
                        help="Path to poses.json file (optional, will use direction if not provided)")
    parser.add_argument("--prompt", type=str,
                        default="A car driving scene captured by front camera",
                        help="Text prompt for generation")
    parser.add_argument("--condition_frames", type=int, default=40,
                        help="Number of condition frames")
    # Raw (pre-compression) frame count.
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate")
    # Effectively divided by 4 downstream (VAE temporal compression).
    parser.add_argument("--height", type=int, default=900,
                        help="Video height")
    parser.add_argument("--width", type=int, default=1600,
                        help="Video width")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    condition_video_path = args.condition_video
    input_filename = os.path.basename(condition_video_path)
    output_dir = "nus/infer_results"
    os.makedirs(output_dir, exist_ok=True)

    # Default output name embeds the movement direction.
    if args.output_path is None:
        name_parts = os.path.splitext(input_filename)
        output_filename = f"{name_parts[0]}_{args.direction}{name_parts[1]}"
        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = args.output_path

    print(f"Output video will be saved to: {output_path}")
    inference_nuscenes_video(
        condition_video_path=args.condition_video,
        dit_path=args.dit_path,
        text_encoder_path=args.text_encoder_path,
        vae_path=args.vae_path,
        output_path=output_path,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        height=args.height,
        width=args.width,
        device=args.device,
        prompt=args.prompt,
        poses_path=args.poses_path,
        direction=args.direction
    )

if __name__ == "__main__":
    main()
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A.

    Both inputs are 4x4 extrinsic matrices; the result is B @ A^-1 in the same
    backend (torch tensor when use_torch, numpy array otherwise). Non-backend
    inputs are converted to float32 first.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Torch backend: promote numpy inputs to float32 tensors.
        a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(b, torch.inverse(a))

    # NumPy backend: promote anything that is not already an ndarray.
    a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(b, np.linalg.inv(a))
import WanModelFuture + from diffsynth.configs.model_config import model_loader_configs + + # 修改model_loader_configs中的配置 + for i, config in enumerate(model_loader_configs): + keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config + + # 检查是否包含wan_video_dit模型 + if 'wan_video_dit' in model_names: + # 找到wan_video_dit的索引并替换为WanModelFuture + new_model_names = [] + new_model_classes = [] + + for name, cls in zip(model_names, model_classes): + if name == 'wan_video_dit': + new_model_names.append(name) # 保持名称不变 + new_model_classes.append(WanModelFuture) # 替换为新的类 + print(f"✅ 替换了模型类: {name} -> WanModelFuture") + else: + new_model_names.append(name) + new_model_classes.append(cls) + + # 更新配置 + model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) + +def add_framepack_components(dit_model): + """添加FramePack相关组件""" + if not hasattr(dit_model, 'clean_x_embedder'): + inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + # 参考hunyuan_video_packed.py的设计 + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + if scale == "1x": + return self.proj(x) + elif scale == "2x": + return self.proj_2x(x) + elif scale == "4x": + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + dit_model.clean_x_embedder = CleanXEmbedder(inner_dim) + model_dtype = next(dit_model.parameters()).dtype + dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype) + print("✅ 添加了FramePack的clean_x_embedder组件") + +def generate_openx_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True): + 
"""为OpenX数据集生成camera embeddings - 滑动窗口版本""" + time_compression_ratio = 4 + + # 计算FramePack实际需要的camera帧数 + framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames + + if use_real_poses and cam_data is not None and 'extrinsic' in cam_data: + print("🔧 使用真实OpenX camera数据") + cam_extrinsic = cam_data['extrinsic'] + + # 确保生成足够长的camera序列 + max_needed_frames = max( + start_frame + current_history_length + new_frames, + framepack_needed_frames, + 30 + ) + + print(f"🔧 计算OpenX camera序列长度:") + print(f" - 基础需求: {start_frame + current_history_length + new_frames}") + print(f" - FramePack需求: {framepack_needed_frames}") + print(f" - 最终生成: {max_needed_frames}") + + relative_poses = [] + for i in range(max_needed_frames): + # OpenX特有:每隔4帧 + frame_idx = i * time_compression_ratio + next_frame_idx = frame_idx + time_compression_ratio + + if next_frame_idx < len(cam_extrinsic): + cam_prev = cam_extrinsic[frame_idx] + cam_next = cam_extrinsic[next_frame_idx] + relative_cam = compute_relative_pose(cam_prev, cam_next) + relative_poses.append(torch.as_tensor(relative_cam[:3, :])) + else: + # 超出范围,使用零运动 + print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动") + relative_poses.append(torch.zeros(3, 4)) + + pose_embedding = torch.stack(relative_poses, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + + # 创建对应长度的mask序列 + mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) + # 从start_frame到current_history_length标记为condition + condition_end = min(start_frame + current_history_length, max_needed_frames) + mask[start_frame:condition_end] = 1.0 + + camera_embedding = torch.cat([pose_embedding, mask], dim=1) + print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}") + return camera_embedding.to(torch.bfloat16) + + else: + print("🔧 使用OpenX合成camera数据") + + max_needed_frames = max( + start_frame + current_history_length + new_frames, + framepack_needed_frames, + 30 + ) + + print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}") + relative_poses = [] + for i in 
range(max_needed_frames): + # OpenX机器人操作模式 - 稳定的小幅度运动 + # 模拟机器人手臂的精细操作 + forward_speed = 0.001 # 每帧前进距离(很小,因为是精细操作) + lateral_motion = 0.0005 * np.sin(i * 0.05) # 轻微的左右移动 + vertical_motion = 0.0003 * np.cos(i * 0.1) # 轻微的上下移动 + + # 旋转变化(模拟视角微调) + yaw_change = 0.01 * np.sin(i * 0.03) # 轻微的偏航 + pitch_change = 0.008 * np.cos(i * 0.04) # 轻微的俯仰 + + pose = np.eye(4, dtype=np.float32) + + # 旋转矩阵(绕Y轴和X轴的小角度旋转) + cos_yaw = np.cos(yaw_change) + sin_yaw = np.sin(yaw_change) + cos_pitch = np.cos(pitch_change) + sin_pitch = np.sin(pitch_change) + + # 组合旋转(先pitch后yaw) + pose[0, 0] = cos_yaw + pose[0, 2] = sin_yaw + pose[1, 1] = cos_pitch + pose[1, 2] = -sin_pitch + pose[2, 0] = -sin_yaw + pose[2, 1] = sin_pitch + pose[2, 2] = cos_yaw * cos_pitch + + # 平移(精细操作的小幅度移动) + pose[0, 3] = lateral_motion # X轴(左右) + pose[1, 3] = vertical_motion # Y轴(上下) + pose[2, 3] = -forward_speed # Z轴(前后,负值表示前进) + + relative_pose = pose[:3, :] + relative_poses.append(torch.as_tensor(relative_pose)) + + pose_embedding = torch.stack(relative_poses, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + + # 创建对应长度的mask序列 + mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) + condition_end = min(start_frame + current_history_length, max_needed_frames) + mask[start_frame:condition_end] = 1.0 + + camera_embedding = torch.cat([pose_embedding, mask], dim=1) + print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}") + return camera_embedding.to(torch.bfloat16) + +def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49): + """FramePack滑动窗口机制 - OpenX版本""" + # history_latents: [C, T, H, W] 当前的历史latents + C, T, H, W = history_latents.shape + + # 固定索引结构(这决定了需要的camera帧数) + total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate + indices = torch.arange(0, total_indices_length) + split_sizes = [1, 16, 2, 1, target_frames_to_generate] + clean_latent_indices_start, 
clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \ + indices.split(split_sizes, dim=0) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0) + + # 检查camera长度是否足够 + if camera_embedding_full.shape[0] < total_indices_length: + shortage = total_indices_length - camera_embedding_full.shape[0] + padding = torch.zeros(shortage, camera_embedding_full.shape[1], + dtype=camera_embedding_full.dtype, device=camera_embedding_full.device) + camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0) + + # 从完整camera序列中选取对应部分 + combined_camera = camera_embedding_full[:total_indices_length, :].clone() + + # 根据当前history length重新设置mask + combined_camera[:, -1] = 0.0 # 先全部设为target (0) + + # 设置condition mask:前19帧根据实际历史长度决定 + if T > 0: + available_frames = min(T, 19) + start_pos = 19 - available_frames + combined_camera[start_pos:19, -1] = 1.0 # 将有效的clean latents对应的camera标记为condition + + print(f"🔧 OpenX Camera mask更新:") + print(f" - 历史帧数: {T}") + print(f" - 有效condition帧数: {available_frames if T > 0 else 0}") + + # 处理latents + clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device) + + if T > 0: + available_frames = min(T, 19) + start_pos = 19 - available_frames + clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :] + + clean_latents_4x = clean_latents_combined[:, 0:16, :, :] + clean_latents_2x = clean_latents_combined[:, 16:18, :, :] + clean_latents_1x = clean_latents_combined[:, 18:19, :, :] + + if T > 0: + start_latent = history_latents[:, 0:1, :, :] + else: + start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device) + + clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1) + + return { + 'latent_indices': latent_indices, + 'clean_latents': clean_latents, + 'clean_latents_2x': clean_latents_2x, + 'clean_latents_4x': clean_latents_4x, + 
'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + 'camera_embedding': combined_camera, + 'current_length': T, + 'next_length': T + target_frames_to_generate + } + +def inference_openx_framepack_sliding_window( + condition_pth_path, + dit_path, + output_path="openx_results/output_openx_framepack_sliding.mp4", + start_frame=0, + initial_condition_frames=8, + frames_per_generation=4, + total_frames_to_generate=32, + max_history_frames=49, + device="cuda", + prompt="A video of robotic manipulation task with camera movement", + use_real_poses=True, + # CFG参数 + use_camera_cfg=True, + camera_guidance_scale=2.0, + text_guidance_scale=1.0 +): + """ + OpenX FramePack滑动窗口视频生成 + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + print(f"🔧 OpenX FramePack滑动窗口生成开始...") + print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}") + print(f"Text guidance scale: {text_guidance_scale}") + + # 1. 模型初始化 + replace_dit_model_in_manager() + + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ]) + pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda") + + # 2. 添加camera编码器 + dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in pipe.dit.blocks: + block.cam_encoder = nn.Linear(13, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + # 3. 添加FramePack组件 + add_framepack_components(pipe.dit) + + # 4. 
加载训练好的权重 + dit_state_dict = torch.load(dit_path, map_location="cpu") + pipe.dit.load_state_dict(dit_state_dict, strict=True) + pipe = pipe.to(device) + model_dtype = next(pipe.dit.parameters()).dtype + + if hasattr(pipe.dit, 'clean_x_embedder'): + pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype) + + pipe.scheduler.set_timesteps(50) + + # 5. 加载初始条件 + print("Loading initial condition frames...") + initial_latents, encoded_data = load_encoded_video_from_pth( + condition_pth_path, + start_frame=start_frame, + num_frames=initial_condition_frames + ) + + # 空间裁剪(适配OpenX数据尺寸) + target_height, target_width = 60, 104 + C, T, H, W = initial_latents.shape + + if H > target_height or W > target_width: + h_start = (H - target_height) // 2 + w_start = (W - target_width) // 2 + initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width] + H, W = target_height, target_width + + history_latents = initial_latents.to(device, dtype=model_dtype) + + print(f"初始history_latents shape: {history_latents.shape}") + + # 6. 编码prompt - 支持CFG + if text_guidance_scale > 1.0: + prompt_emb_pos = pipe.encode_prompt(prompt) + prompt_emb_neg = pipe.encode_prompt("") + print(f"使用Text CFG,guidance scale: {text_guidance_scale}") + else: + prompt_emb_pos = pipe.encode_prompt(prompt) + prompt_emb_neg = None + print("不使用Text CFG") + + # 7. 预生成完整的camera embedding序列 + camera_embedding_full = generate_openx_camera_embeddings_sliding( + encoded_data.get('cam_emb', None), + 0, + max_history_frames, + 0, + 0, + use_real_poses=use_real_poses + ).to(device, dtype=model_dtype) + + print(f"完整camera序列shape: {camera_embedding_full.shape}") + + # 8. 为Camera CFG创建无条件的camera embedding + if use_camera_cfg: + camera_embedding_uncond = torch.zeros_like(camera_embedding_full) + print(f"创建无条件camera embedding用于CFG") + + # 9. 
滑动窗口生成循环 + total_generated = 0 + all_generated_frames = [] + + while total_generated < total_frames_to_generate: + current_generation = min(frames_per_generation, total_frames_to_generate - total_generated) + print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}") + print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}") + + # FramePack数据准备 - OpenX版本 + framepack_data = prepare_framepack_sliding_window_with_camera( + history_latents, + current_generation, + camera_embedding_full, + start_frame, + max_history_frames + ) + + # 准备输入 + clean_latents = framepack_data['clean_latents'].unsqueeze(0) + clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0) + clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0) + camera_embedding = framepack_data['camera_embedding'].unsqueeze(0) + + # 为CFG准备无条件camera embedding + if use_camera_cfg: + camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0) + + # 索引处理 + latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu() + clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu() + clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu() + clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu() + + # 初始化要生成的latents + new_latents = torch.randn( + 1, C, current_generation, H, W, + device=device, dtype=model_dtype + ) + + extra_input = pipe.prepare_extra_input(new_latents) + + print(f"Camera embedding shape: {camera_embedding.shape}") + print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}") + + # 去噪循环 - 支持CFG + timesteps = pipe.scheduler.timesteps + + for i, timestep in enumerate(timesteps): + if i % 10 == 0: + print(f" 去噪步骤 {i}/{len(timesteps)}") + + timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype) + + with torch.no_grad(): + # 
正向预测(带条件) + noise_pred_pos = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + # CFG处理 + if use_camera_cfg and camera_guidance_scale > 1.0: + # 无条件预测(无camera条件) + noise_pred_uncond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding_uncond_batch, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + # Camera CFG + noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_pos - noise_pred_uncond) + else: + noise_pred = noise_pred_pos + + # Text CFG + if prompt_emb_neg is not None and text_guidance_scale > 1.0: + noise_pred_text_uncond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_neg, + **extra_input + ) + + # Text CFG + noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond) + + new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents) + + # 更新历史 + new_latents_squeezed = new_latents.squeeze(0) + history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1) + + # 维护滑动窗口 + if history_latents.shape[1] > 
max_history_frames: + first_frame = history_latents[:, 0:1, :, :] + recent_frames = history_latents[:, -(max_history_frames-1):, :, :] + history_latents = torch.cat([first_frame, recent_frames], dim=1) + print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧") + + print(f"更新后history_latents shape: {history_latents.shape}") + + all_generated_frames.append(new_latents_squeezed) + total_generated += current_generation + + print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧") + + # 10. 解码和保存 + print("\n🔧 解码生成的视频...") + + all_generated = torch.cat(all_generated_frames, dim=1) + final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0) + + print(f"最终视频shape: {final_video.shape}") + + decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + + print(f"Saving video to {output_path}") + + video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy() + video_np = (video_np * 0.5 + 0.5).clip(0, 1) + video_np = (video_np * 255).astype(np.uint8) + + with imageio.get_writer(output_path, fps=20) as writer: + for frame in video_np: + writer.append_data(frame) + + print(f"🔧 OpenX FramePack滑动窗口生成完成! 
保存到: {output_path}") + print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧") + +def main(): + parser = argparse.ArgumentParser(description="OpenX FramePack滑动窗口视频生成") + + # 基础参数 + parser.add_argument("--condition_pth", type=str, + default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth", + help="输入编码视频路径") + parser.add_argument("--start_frame", type=int, default=0) + parser.add_argument("--initial_condition_frames", type=int, default=16) + parser.add_argument("--frames_per_generation", type=int, default=8) + parser.add_argument("--total_frames_to_generate", type=int, default=24) + parser.add_argument("--max_history_frames", type=int, default=100) + parser.add_argument("--use_real_poses", action="store_true", default=False) + parser.add_argument("--dit_path", type=str, + default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/openx/openx_framepack/step2000.ckpt", + help="训练好的模型权重路径") + parser.add_argument("--output_path", type=str, + default='openx_results/output_openx_framepack_sliding.mp4') + parser.add_argument("--prompt", type=str, + default="A video of robotic manipulation task with camera movement") + parser.add_argument("--device", type=str, default="cuda") + + # CFG参数 + parser.add_argument("--use_camera_cfg", action="store_true", default=True, + help="使用Camera CFG") + parser.add_argument("--camera_guidance_scale", type=float, default=2.0, + help="Camera guidance scale for CFG") + parser.add_argument("--text_guidance_scale", type=float, default=1.0, + help="Text guidance scale for CFG") + + args = parser.parse_args() + + print(f"🔧 OpenX FramePack CFG生成设置:") + print(f"Camera CFG: {args.use_camera_cfg}") + if args.use_camera_cfg: + print(f"Camera guidance scale: {args.camera_guidance_scale}") + print(f"Text guidance scale: {args.text_guidance_scale}") + print(f"OpenX特有特性: camera间隔为4帧,适用于机器人操作任务") + + inference_openx_framepack_sliding_window( + condition_pth_path=args.condition_pth, + dit_path=args.dit_path, + 
output_path=args.output_path, + start_frame=args.start_frame, + initial_condition_frames=args.initial_condition_frames, + frames_per_generation=args.frames_per_generation, + total_frames_to_generate=args.total_frames_to_generate, + max_history_frames=args.max_history_frames, + device=args.device, + prompt=args.prompt, + use_real_poses=args.use_real_poses, + # CFG参数 + use_camera_cfg=args.use_camera_cfg, + camera_guidance_scale=args.camera_guidance_scale, + text_guidance_scale=args.text_guidance_scale + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/infer_origin.py b/scripts/infer_origin.py new file mode 100644 index 0000000000000000000000000000000000000000..806f1207f89d0efa7b91d009b3494f57c9740f4f --- /dev/null +++ b/scripts/infer_origin.py @@ -0,0 +1,1109 @@ +import os +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +import imageio +import json +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import argparse +from torchvision.transforms import v2 +from einops import rearrange +import copy +from scipy.spatial.transform import Rotation as R + +def compute_relative_pose_matrix(pose1, pose2): + """ + 计算相邻两帧的相对位姿,返回3×4的相机矩阵 [R_rel | t_rel] + + 参数: + pose1: 第i帧的相机位姿,形状为(7,)的数组 [tx1, ty1, tz1, qx1, qy1, qz1, qw1] + pose2: 第i+1帧的相机位姿,形状为(7,)的数组 [tx2, ty2, tz2, qx2, qy2, qz2, qw2] + + 返回: + relative_matrix: 3×4的相对位姿矩阵,前3列是旋转矩阵R_rel,第4列是平移向量t_rel + """ + # 分离平移向量和四元数 + t1 = pose1[:3] # 第i帧平移 [tx1, ty1, tz1] + q1 = pose1[3:] # 第i帧四元数 [qx1, qy1, qz1, qw1] + t2 = pose2[:3] # 第i+1帧平移 + q2 = pose2[3:] # 第i+1帧四元数 + + # 1. 计算相对旋转矩阵 R_rel + rot1 = R.from_quat(q1) # 第i帧旋转 + rot2 = R.from_quat(q2) # 第i+1帧旋转 + rot_rel = rot2 * rot1.inv() # 相对旋转 = 后一帧旋转 × 前一帧旋转的逆 + R_rel = rot_rel.as_matrix() # 转换为3×3矩阵 + + # 2. 计算相对平移向量 t_rel + R1_T = rot1.as_matrix().T # 前一帧旋转矩阵的转置(等价于逆) + t_rel = R1_T @ (t2 - t1) # 相对平移 = R1^T × (t2 - t1) + + # 3. 
组合为3×4矩阵 [R_rel | t_rel] + relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)]) + + return relative_matrix + +def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10): + """从pth文件加载预编码的视频数据""" + print(f"Loading encoded video from {pth_path}") + + encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu") + full_latents = encoded_data['latents'] # [C, T, H, W] + + print(f"Full latents shape: {full_latents.shape}") + print(f"Extracting frames {start_frame} to {start_frame + num_frames}") + + if start_frame + num_frames > full_latents.shape[1]: + raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}") + + condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :] + print(f"Extracted condition latents shape: {condition_latents.shape}") + + return condition_latents, encoded_data + + +def compute_relative_pose(pose_a, pose_b, use_torch=False): + """计算相机B相对于相机A的相对位姿矩阵""" + assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}" + assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}" + + if use_torch: + if not isinstance(pose_a, torch.Tensor): + pose_a = torch.from_numpy(pose_a).float() + if not isinstance(pose_b, torch.Tensor): + pose_b = torch.from_numpy(pose_b).float() + + pose_a_inv = torch.inverse(pose_a) + relative_pose = torch.matmul(pose_b, pose_a_inv) + else: + if not isinstance(pose_a, np.ndarray): + pose_a = np.array(pose_a, dtype=np.float32) + if not isinstance(pose_b, np.ndarray): + pose_b = np.array(pose_b, dtype=np.float32) + + pose_a_inv = np.linalg.inv(pose_a) + relative_pose = np.matmul(pose_b, pose_a_inv) + + return relative_pose + + +def replace_dit_model_in_manager(): + """替换DiT模型类为MoE版本""" + from diffsynth.models.wan_video_dit_moe import WanModelMoe + from diffsynth.configs.model_config import model_loader_configs + + for i, config in enumerate(model_loader_configs): + keys_hash, keys_hash_with_shape, 
model_names, model_classes, model_resource = config + + if 'wan_video_dit' in model_names: + new_model_names = [] + new_model_classes = [] + + for name, cls in zip(model_names, model_classes): + if name == 'wan_video_dit': + new_model_names.append(name) + new_model_classes.append(WanModelMoe) + print(f"✅ 替换了模型类: {name} -> WanModelMoe") + else: + new_model_names.append(name) + new_model_classes.append(cls) + + model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) + + +def add_framepack_components(dit_model): + """添加FramePack相关组件""" + if not hasattr(dit_model, 'clean_x_embedder'): + inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + if scale == "1x": + x = x.to(self.proj.weight.dtype) + return self.proj(x) + elif scale == "2x": + x = x.to(self.proj_2x.weight.dtype) + return self.proj_2x(x) + elif scale == "4x": + x = x.to(self.proj_4x.weight.dtype) + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + dit_model.clean_x_embedder = CleanXEmbedder(inner_dim) + model_dtype = next(dit_model.parameters()).dtype + dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype) + print("✅ 添加了FramePack的clean_x_embedder组件") + + +def add_moe_components(dit_model, moe_config): + """🔧 添加MoE相关组件 - 修正版本""" + if not hasattr(dit_model, 'moe_config'): + dit_model.moe_config = moe_config + print("✅ 添加了MoE配置到模型") + dit_model.top_k = moe_config.get("top_k", 1) + + # 为每个block动态添加MoE组件 + dim = dit_model.blocks[0].self_attn.q.weight.shape[0] + unified_dim = moe_config.get("unified_dim", 25) + num_experts = 
moe_config.get("num_experts", 4) + from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE + dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim) + dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim) + dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim) # OpenX使用13维输入,类似sekai但独立处理 + dit_model.global_router = nn.Linear(unified_dim, num_experts) + + + for i, block in enumerate(dit_model.blocks): + # MoE网络 - 输入unified_dim,输出dim + block.moe = MultiModalMoE( + unified_dim=unified_dim, + output_dim=dim, # 输出维度匹配transformer block的dim + num_experts=moe_config.get("num_experts", 4), + top_k=moe_config.get("top_k", 2) + ) + + print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})") + + +def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True,direction="left"): + """为Sekai数据集生成camera embeddings - 滑动窗口版本""" + time_compression_ratio = 4 + + # 计算FramePack实际需要的camera帧数 + framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames + + if use_real_poses and cam_data is not None and 'extrinsic' in cam_data: + print("🔧 使用真实Sekai camera数据") + cam_extrinsic = cam_data['extrinsic'] + + # 确保生成足够长的camera序列 + max_needed_frames = max( + start_frame + current_history_length + new_frames, + framepack_needed_frames, + 30 + ) + + print(f"🔧 计算Sekai camera序列长度:") + print(f" - 基础需求: {start_frame + current_history_length + new_frames}") + print(f" - FramePack需求: {framepack_needed_frames}") + print(f" - 最终生成: {max_needed_frames}") + + relative_poses = [] + for i in range(max_needed_frames): + # 计算当前帧在原始序列中的位置 + frame_idx = i * time_compression_ratio + next_frame_idx = frame_idx + time_compression_ratio + + if next_frame_idx < len(cam_extrinsic): + cam_prev = cam_extrinsic[frame_idx] + cam_next = cam_extrinsic[next_frame_idx] + relative_pose = compute_relative_pose(cam_prev, cam_next) + 
relative_poses.append(torch.as_tensor(relative_pose[:3, :])) + else: + # 超出范围,使用零运动 + print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动") + relative_poses.append(torch.zeros(3, 4)) + + pose_embedding = torch.stack(relative_poses, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + + # 创建对应长度的mask序列 + mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) + # 从start_frame到current_history_length标记为condition + condition_end = min(start_frame + current_history_length, max_needed_frames) + mask[start_frame:condition_end] = 1.0 + + camera_embedding = torch.cat([pose_embedding, mask], dim=1) + print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}") + return camera_embedding.to(torch.bfloat16) + + else: + if direction=="left": + print("-----Left-------") + + max_needed_frames = max( + start_frame + current_history_length + new_frames, + framepack_needed_frames, + 30 + ) + + print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}") + relative_poses = [] + for i in range(max_needed_frames): + # 持续左转运动模式 + yaw_per_frame = 0.05 # 每帧左转(正角度表示左转) + forward_speed = 0.05 # 每帧前进距离 + + pose = np.eye(4, dtype=np.float32) + + # 旋转矩阵(绕Y轴左转) + cos_yaw = np.cos(yaw_per_frame) + sin_yaw = np.sin(yaw_per_frame) + + pose[0, 0] = cos_yaw + pose[0, 2] = sin_yaw + pose[2, 0] = -sin_yaw + pose[2, 2] = cos_yaw + + # 平移(在旋转后的局部坐标系中前进) + pose[2, 3] = -forward_speed # 局部Z轴负方向(前进) + + # 添加轻微的向心运动,模拟圆形轨迹 + radius_drift = 0.002 # 向圆心的轻微漂移 + pose[0, 3] = -radius_drift # 局部X轴负方向(向左) + + relative_pose = pose[:3, :] + relative_poses.append(torch.as_tensor(relative_pose)) + + pose_embedding = torch.stack(relative_poses, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + + # 创建对应长度的mask序列 + mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) + condition_end = min(start_frame + current_history_length, max_needed_frames) + mask[start_frame:condition_end] = 1.0 + + camera_embedding = torch.cat([pose_embedding, mask], dim=1) + print(f"🔧 Sekai合成camera 
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Build the per-frame camera embedding sequence for the OpenX modality (sliding-window version).

    Args:
        encoded_data: dict loaded from the encoded .pth file; real extrinsics are read
            from ``encoded_data['cam_emb']['extrinsic']`` when present.
        start_frame: index of the first condition frame inside the generated sequence.
        current_history_length: number of latent frames currently used as condition.
        new_frames: number of latent frames to be generated in this step.
        use_real_poses: when True and real extrinsics exist, derive relative poses from
            them; otherwise synthesize a gentle robot-arm-like motion.

    Returns:
        torch.bfloat16 tensor of shape [max_needed_frames, 13]: a flattened 3x4
        relative pose (12 values) plus a 1-D condition mask
        (1.0 = condition frame, 0.0 = target frame).
    """
    # Latent frames are 4x temporally compressed relative to pixel frames.
    time_compression_ratio = 4

    # Camera frames FramePack consumes: start(1) + 4x(16) + 2x(2) + 1x(1) + targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Always generate a sequence long enough for both the sliding window and
    # FramePack's fixed index layout (identical formula in both branches below,
    # so it is hoisted out of them).
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f"   - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f"   - FramePack需求: {framepack_needed_frames}")
        print(f"   - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX samples every 4th extrinsic, like sekai, but sequences are shorter.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded extrinsics: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        # Flatten each 3x4 pose to 12 values (was einops rearrange 'b c d -> b (c d)').
        pose_embedding = pose_embedding.reshape(max_needed_frames, 12)
    else:
        print("🔧 使用OpenX合成camera数据")
        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

        # Synthetic robot-manipulation motion: small roll/pitch/yaw plus a slow
        # push, meant to mimic fine robot-arm movement. The per-frame relative
        # pose is constant, so compute it once and repeat it instead of
        # rebuilding the same matrix every iteration.
        roll_per_frame = 0.02   # slight roll
        pitch_per_frame = 0.01  # slight pitch
        yaw_per_frame = 0.015   # slight yaw
        forward_speed = 0.003   # slow forward speed

        pose = np.eye(4, dtype=np.float32)

        cos_roll = np.cos(roll_per_frame)
        sin_roll = np.sin(roll_per_frame)
        cos_pitch = np.cos(pitch_per_frame)
        sin_pitch = np.sin(pitch_per_frame)
        cos_yaw = np.cos(yaw_per_frame)
        sin_yaw = np.sin(yaw_per_frame)

        # Simplified composite rotation matrix (ZYX order).
        pose[0, 0] = cos_yaw * cos_pitch
        pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
        pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
        pose[1, 0] = sin_yaw * cos_pitch
        pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
        pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
        pose[2, 0] = -sin_pitch
        pose[2, 1] = cos_pitch * sin_roll
        pose[2, 2] = cos_pitch * cos_roll

        # Translation: fine manipulation drift, dominated by the depth axis.
        pose[0, 3] = forward_speed * 0.5
        pose[1, 3] = forward_speed * 0.3
        pose[2, 3] = -forward_speed

        relative_pose = torch.as_tensor(pose[:3, :])
        pose_embedding = relative_pose.unsqueeze(0).repeat(max_needed_frames, 1, 1)
        # Flatten each 3x4 pose to 12 values (was einops rearrange 'b c d -> b (c d)').
        pose_embedding = pose_embedding.reshape(max_needed_frames, 12)

    # Condition mask: 1.0 on [start_frame, start_frame + history), 0.0 elsewhere.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 OpenX camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)


def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Build per-frame camera embeddings for the NuScenes modality (sliding-window version).

    Kept consistent with train_moe.py: unlike the sekai/OpenX variants, the
    sequence length here is only max(framepack_needed_frames, 30).

    Args:
        scene_info: optional dict with a ``keyframe_poses`` list of
            {'translation': [...], 'rotation': [...]} entries; when absent,
            a synthetic left-turn trajectory is produced.
        start_frame / current_history_length / new_frames: as in the OpenX variant.

    Returns:
        torch.bfloat16 tensor of shape [max_needed_frames, 8]:
        7-D pose (tx, ty, tz, qw, qx, qy, qz) + 1-D condition mask.
    """
    time_compression_ratio = 4  # kept for parity with the other modalities

    # Camera frames FramePack consumes: start(1) + 4x(16) + 2x(2) + 1x(1) + targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if len(keyframe_poses) == 0:
            # Degenerate scene: emit all-zero poses so the caller still gets a
            # correctly shaped embedding.
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            max_needed_frames = max(framepack_needed_frames, 30)

            pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)

            mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
            condition_end = min(start_frame + current_history_length, max_needed_frames)
            mask[start_frame:condition_end] = 1.0

            camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
            print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)

        # The first keyframe is the reference: translations become relative to it.
        reference_pose = keyframe_poses[0]

        max_needed_frames = max(framepack_needed_frames, 30)

        pose_vecs = []
        for i in range(max_needed_frames):
            if i < len(keyframe_poses):
                current_pose = keyframe_poses[i]

                # Relative displacement w.r.t. the reference keyframe.
                translation = torch.tensor(
                    np.array(current_pose['translation']) - np.array(reference_pose['translation']),
                    dtype=torch.float32
                )

                # Rotation is passed through as-is (simplified; not made relative).
                rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)

                pose_vec = torch.cat([translation, rotation], dim=0)  # [7]
            else:
                # Past the recorded keyframes: zero translation, identity quaternion.
                pose_vec = torch.cat([
                    torch.zeros(3, dtype=torch.float32),
                    torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
                ], dim=0)  # [7]

            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)  # [max_needed_frames, 7]
    else:
        print("🔧 使用NuScenes合成pose数据")
        max_needed_frames = max(framepack_needed_frames, 30)

        # Synthetic left-turn motion, like a car turning in city driving:
        # positions sampled along a circular arc, heading along its tangent.
        pose_vecs = []
        for i in range(max_needed_frames):
            angle = i * 0.04   # 0.04 rad of turn per frame
            radius = 15.0      # large turning radius, car-like

            x = radius * np.sin(angle)
            y = 0.0            # stay in the horizontal plane
            z = radius * (1 - np.cos(angle))

            translation = torch.tensor([x, y, z], dtype=torch.float32)

            # Heading: always tangent to the arc, expressed as a quaternion
            # for a rotation about the Y axis.
            yaw = angle + np.pi / 2  # yaw relative to the initial forward direction
            rotation = torch.tensor([
                np.cos(yaw / 2),  # w (real part)
                0.0,              # x
                0.0,              # y
                np.sin(yaw / 2)   # z (imaginary part, about Y axis)
            ], dtype=torch.float32)

            pose_vec = torch.cat([translation, rotation], dim=0)  # [7: tx,ty,tz,qw,qx,qy,qz]
            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)

    # Condition mask: 1.0 on [start_frame, start_frame + history), 0.0 elsewhere.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
    print(f"🔧 NuScenes pose embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """FramePack sliding-window preparation - MoE variant.

    Builds the fixed FramePack index layout, selects (and zero-pads if needed)
    the matching slice of the precomputed camera sequence, rebuilds its
    condition mask from the current history length, and assembles the
    clean-latent tensors from ``history_latents``.

    Args:
        history_latents: [C, T, H, W] latents accumulated so far.
        target_frames_to_generate: latent frames to produce this step.
        camera_embedding_full: [N, D] full camera sequence; last column is the mask.
        start_frame: kept for interface parity (mask is rebuilt from history here).
        modality_type: modality tag propagated to the caller.
        max_history_frames: kept for interface parity; the clean-latent window
            is fixed at 19 slots below.

    Returns:
        dict with index tensors, clean latents, the sliced camera embedding,
        the modality tag, and current/next history lengths.
    """
    num_channels, history_len, height, width = history_latents.shape

    # Fixed index layout: [start | 16 x 4x | 2 x 2x | 1 x 1x | targets].
    section_sizes = [1, 16, 2, 1, target_frames_to_generate]
    total_len = sum(section_sizes)
    idx_start, idx_4x, idx_2x, idx_1x, latent_indices = \
        torch.arange(0, total_len).split(section_sizes, dim=0)
    clean_latent_indices = torch.cat([idx_start, idx_1x], dim=0)

    # Zero-pad the camera sequence if it is shorter than the index layout.
    shortfall = total_len - camera_embedding_full.shape[0]
    if shortfall > 0:
        pad = torch.zeros(shortfall, camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype,
                          device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Slice the window and rebuild the mask: everything defaults to target (0),
    # then the slots backed by real history become condition (1).
    combined_camera = camera_embedding_full[:total_len, :].clone()
    combined_camera[:, -1] = 0.0

    # The first 19 index slots hold clean (condition) latents; they are filled
    # right-aligned, so mark only the trailing `valid` slots.
    valid = min(history_len, 19)
    offset = 19 - valid
    if history_len > 0:
        combined_camera[offset:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f"   - 历史帧数: {history_len}")
    print(f"   - 有效condition帧数: {valid if history_len > 0 else 0}")
    print(f"   - 模态类型: {modality_type}")

    # Right-align the most recent history into the 19 clean-latent slots.
    clean_latents_combined = torch.zeros(num_channels, 19, height, width,
                                         dtype=history_latents.dtype,
                                         device=history_latents.device)
    if history_len > 0:
        clean_latents_combined[:, offset:, :, :] = history_latents[:, -valid:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # The anchor slot always carries the very first frame (or zeros when empty).
    if history_len > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(num_channels, 1, height, width,
                                   dtype=history_latents.dtype,
                                   device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag for the MoE router
        'current_length': history_len,
        'next_length': history_len + target_frames_to_generate
    }
total_frames_to_generate=32, + max_history_frames=49, + device="cuda", + prompt="A video of a scene shot using a pedestrian's front camera while walking", + modality_type="sekai", # "sekai" 或 "nuscenes" + use_real_poses=True, + scene_info_path=None, # 对于NuScenes数据集 + # CFG参数 + use_camera_cfg=True, + camera_guidance_scale=2.0, + text_guidance_scale=1.0, + # MoE参数 + moe_num_experts=4, + moe_top_k=2, + moe_hidden_dim=None, + direction="left", + use_gt_prompt=True +): + """ + MoE FramePack滑动窗口视频生成 - 支持多模态 + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + print(f"🔧 MoE FramePack滑动窗口生成开始...") + print(f"模态类型: {modality_type}") + print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}") + print(f"Text guidance scale: {text_guidance_scale}") + print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}") + + # 1. 模型初始化 + replace_dit_model_in_manager() + + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ]) + pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda") + + # 2. 添加传统camera编码器(兼容性) + dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in pipe.dit.blocks: + block.cam_encoder = nn.Linear(13, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + # 3. 添加FramePack组件 + add_framepack_components(pipe.dit) + + # 4. 
添加MoE组件 + moe_config = { + "num_experts": moe_num_experts, + "top_k": moe_top_k, + "hidden_dim": moe_hidden_dim or dim * 2, + "sekai_input_dim": 13, # Sekai: 12维pose + 1维mask + "nuscenes_input_dim": 8, # NuScenes: 7维pose + 1维mask + "openx_input_dim": 13 # OpenX: 12维pose + 1维mask (类似sekai) + } + add_moe_components(pipe.dit, moe_config) + + # 5. 加载训练好的权重 + dit_state_dict = torch.load(dit_path, map_location="cpu") + pipe.dit.load_state_dict(dit_state_dict, strict=False) # 使用strict=False以兼容新增的MoE组件 + pipe = pipe.to(device) + model_dtype = next(pipe.dit.parameters()).dtype + + if hasattr(pipe.dit, 'clean_x_embedder'): + pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype) + + pipe.scheduler.set_timesteps(50) + + # 6. 加载初始条件 + print("Loading initial condition frames...") + initial_latents, encoded_data = load_encoded_video_from_pth( + condition_pth_path, + start_frame=start_frame, + num_frames=initial_condition_frames + ) + + # 空间裁剪 + target_height, target_width = 60, 104 + C, T, H, W = initial_latents.shape + + if H > target_height or W > target_width: + h_start = (H - target_height) // 2 + w_start = (W - target_width) // 2 + initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width] + H, W = target_height, target_width + + history_latents = initial_latents.to(device, dtype=model_dtype) + + print(f"初始history_latents shape: {history_latents.shape}") + + # 7. 
编码prompt - 支持CFG + if use_gt_prompt and 'prompt_emb' in encoded_data: + print("✅ 使用预编码的GT prompt embedding") + prompt_emb_pos = encoded_data['prompt_emb'] + # 将prompt_emb移到正确的设备和数据类型 + if 'context' in prompt_emb_pos: + prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype) + if 'context_mask' in prompt_emb_pos: + prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype) + + # 如果使用Text CFG,生成负向prompt + if text_guidance_scale > 1.0: + prompt_emb_neg = pipe.encode_prompt("") + print(f"使用Text CFG with GT prompt,guidance scale: {text_guidance_scale}") + else: + prompt_emb_neg = None + print("不使用Text CFG") + + # 🔧 打印GT prompt文本(如果有) + if 'prompt' in encoded_data['prompt_emb']: + gt_prompt_text = encoded_data['prompt_emb']['prompt'] + print(f"📝 GT Prompt文本: {gt_prompt_text}") + else: + # 使用传入的prompt参数重新编码 + print(f"🔄 重新编码prompt: {prompt}") + if text_guidance_scale > 1.0: + prompt_emb_pos = pipe.encode_prompt(prompt) + prompt_emb_neg = pipe.encode_prompt("") + print(f"使用Text CFG,guidance scale: {text_guidance_scale}") + else: + prompt_emb_pos = pipe.encode_prompt(prompt) + prompt_emb_neg = None + print("不使用Text CFG") + + # 8. 加载场景信息(对于NuScenes) + scene_info = None + if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path): + with open(scene_info_path, 'r') as f: + scene_info = json.load(f) + print(f"加载NuScenes场景信息: {scene_info_path}") + + # 9. 
预生成完整的camera embedding序列 + if modality_type == "sekai": + camera_embedding_full = generate_sekai_camera_embeddings_sliding( + encoded_data.get('cam_emb', None), + 0, + max_history_frames, + 0, + 0, + use_real_poses=use_real_poses, + direction=direction + ).to(device, dtype=model_dtype) + elif modality_type == "nuscenes": + camera_embedding_full = generate_nuscenes_camera_embeddings_sliding( + scene_info, + 0, + max_history_frames, + 0 + ).to(device, dtype=model_dtype) + elif modality_type == "openx": + camera_embedding_full = generate_openx_camera_embeddings_sliding( + encoded_data, + 0, + max_history_frames, + 0, + use_real_poses=use_real_poses + ).to(device, dtype=model_dtype) + else: + raise ValueError(f"不支持的模态类型: {modality_type}") + + print(f"完整camera序列shape: {camera_embedding_full.shape}") + + # 10. 为Camera CFG创建无条件的camera embedding + if use_camera_cfg: + camera_embedding_uncond = torch.zeros_like(camera_embedding_full) + print(f"创建无条件camera embedding用于CFG") + + # 11. 滑动窗口生成循环 + total_generated = 0 + all_generated_frames = [] + + while total_generated < total_frames_to_generate: + current_generation = min(frames_per_generation, total_frames_to_generate - total_generated) + print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}") + print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}") + + # FramePack数据准备 - MoE版本 + framepack_data = prepare_framepack_sliding_window_with_camera_moe( + history_latents, + current_generation, + camera_embedding_full, + start_frame, + modality_type, + max_history_frames + ) + + # 准备输入 + clean_latents = framepack_data['clean_latents'].unsqueeze(0) + clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0) + clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0) + camera_embedding = framepack_data['camera_embedding'].unsqueeze(0) + + # 准备modality_inputs + modality_inputs = {modality_type: camera_embedding} + + # 为CFG准备无条件camera embedding + if use_camera_cfg: + 
camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0) + modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch} + + # 索引处理 + latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu() + clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu() + clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu() + clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu() + + # 初始化要生成的latents + new_latents = torch.randn( + 1, C, current_generation, H, W, + device=device, dtype=model_dtype + ) + + extra_input = pipe.prepare_extra_input(new_latents) + + print(f"Camera embedding shape: {camera_embedding.shape}") + print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}") + + # 去噪循环 - 支持CFG + timesteps = pipe.scheduler.timesteps + + for i, timestep in enumerate(timesteps): + if i % 10 == 0: + print(f" 去噪步骤 {i+1}/{len(timesteps)}") + + timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype) + + with torch.no_grad(): + # CFG推理 + if use_camera_cfg and camera_guidance_scale > 1.0: + # 条件预测(有camera) + noise_pred_cond, moe_loess = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, # MoE模态输入 + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + # 无条件预测(无camera) + noise_pred_uncond, moe_loess = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding_uncond_batch, + modality_inputs=modality_inputs_uncond, # MoE无条件模态输入 + latent_indices=latent_indices, + 
clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos), + **extra_input + ) + + # Camera CFG + noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # 如果同时使用Text CFG + if text_guidance_scale > 1.0 and prompt_emb_neg: + noise_pred_text_uncond, moe_loess = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_neg, + **extra_input + ) + + # 应用Text CFG到已经应用Camera CFG的结果 + noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond) + + elif text_guidance_scale > 1.0 and prompt_emb_neg: + # 只使用Text CFG + noise_pred_cond, moe_loess = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + noise_pred_uncond, moe_loess= pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + 
clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_neg, + **extra_input + ) + + noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond) + + else: + # 标准推理(无CFG) + noise_pred, moe_loess = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + modality_inputs=modality_inputs, # MoE模态输入 + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents) + + # 更新历史 + new_latents_squeezed = new_latents.squeeze(0) + history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1) + + # 维护滑动窗口 + if history_latents.shape[1] > max_history_frames: + first_frame = history_latents[:, 0:1, :, :] + recent_frames = history_latents[:, -(max_history_frames-1):, :, :] + history_latents = torch.cat([first_frame, recent_frames], dim=1) + print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧") + + print(f"更新后history_latents shape: {history_latents.shape}") + + all_generated_frames.append(new_latents_squeezed) + total_generated += current_generation + + print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧") + + # 12. 
解码和保存 + print("\n🔧 解码生成的视频...") + + all_generated = torch.cat(all_generated_frames, dim=1) + final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0) + + print(f"最终视频shape: {final_video.shape}") + + decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + + print(f"Saving video to {output_path}") + + video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy() + video_np = (video_np * 0.5 + 0.5).clip(0, 1) + video_np = (video_np * 255).astype(np.uint8) + + with imageio.get_writer(output_path, fps=20) as writer: + for frame in video_np: + writer.append_data(frame) + + print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}") + print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧") + print(f"使用模态: {modality_type}") + + +def main(): + parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态") + + # 基础参数 + parser.add_argument("--condition_pth", type=str, + #default="/share_zhuyixuan05/zhuyixuan05/sekai-game-drone/00500210001_0012150_0012450/encoded_video.pth") + default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth") + #default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth") + #default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth") + parser.add_argument("--start_frame", type=int, default=0) + parser.add_argument("--initial_condition_frames", type=int, default=16) + parser.add_argument("--frames_per_generation", type=int, default=8) + parser.add_argument("--total_frames_to_generate", type=int, default=24) + parser.add_argument("--max_history_frames", type=int, default=100) + parser.add_argument("--use_real_poses", default=False) + parser.add_argument("--dit_path", type=str, + 
default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt") + parser.add_argument("--output_path", type=str, + default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4') + parser.add_argument("--prompt", type=str, + default="A car is driving") + parser.add_argument("--device", type=str, default="cuda") + + # 模态类型参数 + parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="nuscenes", + help="模态类型:sekai 或 nuscenes 或 openx") + parser.add_argument("--scene_info_path", type=str, default=None, + help="NuScenes场景信息文件路径(仅用于nuscenes模态)") + + # CFG参数 + parser.add_argument("--use_camera_cfg", default=False, + help="使用Camera CFG") + parser.add_argument("--camera_guidance_scale", type=float, default=2.0, + help="Camera guidance scale for CFG") + parser.add_argument("--text_guidance_scale", type=float, default=1.0, + help="Text guidance scale for CFG") + + # MoE参数 + parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量") + parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家") + parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度") + parser.add_argument("--direction", type=str, default="left") + parser.add_argument("--use_gt_prompt", action="store_true", default=False, + help="使用数据集中的ground truth prompt embedding") + + args = parser.parse_args() + + print(f"🔧 MoE FramePack CFG生成设置:") + print(f"模态类型: {args.modality_type}") + print(f"Camera CFG: {args.use_camera_cfg}") + if args.use_camera_cfg: + print(f"Camera guidance scale: {args.camera_guidance_scale}") + print(f"使用GT Prompt: {args.use_gt_prompt}") + print(f"Text guidance scale: {args.text_guidance_scale}") + print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}") + print(f"DiT{args.dit_path}") + + # 验证NuScenes参数 + if args.modality_type == "nuscenes" and not args.scene_info_path: + print("⚠️ 
使用NuScenes模态但未提供scene_info_path,将使用合成pose数据") + + inference_moe_framepack_sliding_window( + condition_pth_path=args.condition_pth, + dit_path=args.dit_path, + output_path=args.output_path, + start_frame=args.start_frame, + initial_condition_frames=args.initial_condition_frames, + frames_per_generation=args.frames_per_generation, + total_frames_to_generate=args.total_frames_to_generate, + max_history_frames=args.max_history_frames, + device=args.device, + prompt=args.prompt, + modality_type=args.modality_type, + use_real_poses=args.use_real_poses, + scene_info_path=args.scene_info_path, + # CFG参数 + use_camera_cfg=args.use_camera_cfg, + camera_guidance_scale=args.camera_guidance_scale, + text_guidance_scale=args.text_guidance_scale, + # MoE参数 + moe_num_experts=args.moe_num_experts, + moe_top_k=args.moe_top_k, + moe_hidden_dim=args.moe_hidden_dim, + direction=args.direction, + use_gt_prompt=args.use_gt_prompt + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/infer_recam.py b/scripts/infer_recam.py new file mode 100644 index 0000000000000000000000000000000000000000..292948a9d7f4315b039935352bf422f8964838af --- /dev/null +++ b/scripts/infer_recam.py @@ -0,0 +1,272 @@ +import sys +import torch +import torch.nn as nn +from diffsynth import ModelManager, WanVideoReCamMasterPipeline, save_video, VideoData +import torch, os, imageio, argparse +from torchvision.transforms import v2 +from einops import rearrange +import pandas as pd +import torchvision +from PIL import Image +import numpy as np +import json + +class Camera(object): + def __init__(self, c2w): + c2w_mat = np.array(c2w).reshape(4, 4) + self.c2w_mat = c2w_mat + self.w2c_mat = np.linalg.inv(c2w_mat) + +class TextVideoCameraDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, args, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, condition_frames=40, target_frames=20): + metadata = 
pd.read_csv(metadata_path) + self.path = [os.path.join(base_path, "videos", file_name) for file_name in metadata["file_name"]] + self.text = metadata["text"].to_list() + + self.max_num_frames = max_num_frames + self.frame_interval = frame_interval + self.num_frames = num_frames + self.height = height + self.width = width + self.is_i2v = is_i2v + self.args = args + self.cam_type = self.args.cam_type + + # 🔧 新增:保存帧数配置 + self.condition_frames = condition_frames + self.target_frames = target_frames + + self.frame_process = v2.Compose([ + v2.CenterCrop(size=(height, width)), + v2.Resize(size=(height, width), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def crop_and_resize(self, image): + width, height = image.size + scale = max(self.width / width, self.height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + return image + + def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process): + reader = imageio.get_reader(file_path) + if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: + reader.close() + return None + + frames = [] + first_frame = None + for frame_id in range(num_frames): + frame = reader.get_data(start_frame_id + frame_id * interval) + frame = Image.fromarray(frame) + frame = self.crop_and_resize(frame) + if first_frame is None: + first_frame = np.array(frame) + frame = frame_process(frame) + frames.append(frame) + reader.close() + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + + if self.is_i2v: + return frames, first_frame + else: + return frames + + def is_image(self, file_path): + file_ext_name = file_path.split(".")[-1] + if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]: + return True + 
return False + + def load_video(self, file_path): + start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0] + frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process) + return frames + + def parse_matrix(self, matrix_str): + rows = matrix_str.strip().split('] [') + matrix = [] + for row in rows: + row = row.replace('[', '').replace(']', '') + matrix.append(list(map(float, row.split()))) + return np.array(matrix) + + def get_relative_pose(self, cam_params): + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + + cam_to_origin = 0 + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, -cam_to_origin], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + + def __getitem__(self, data_id): + text = self.text[data_id] + path = self.path[data_id] + video = self.load_video(path) + if video is None: + raise ValueError(f"{path} is not a valid video.") + num_frames = video.shape[1] + assert num_frames == 81 + data = {"text": text, "video": video, "path": path} + + # load camera + tgt_camera_path = "./example_test_data/cameras/camera_extrinsics.json" + with open(tgt_camera_path, 'r') as file: + cam_data = json.load(file) + + # 🔧 修改:生成target_frames长度的相机轨迹 + cam_idx = np.linspace(0, 80, self.target_frames, dtype=int).tolist() # 改为target_frames长度 + traj = [self.parse_matrix(cam_data[f"frame{idx}"][f"cam{int(self.cam_type):02d}"]) for idx in cam_idx] + traj = np.stack(traj).transpose(0, 2, 1) + c2ws = [] + for c2w in traj: + c2w = c2w[:, [1, 2, 0, 3]] + c2w[:3, 1] *= -1. 
+ c2w[:3, 3] /= 100 + c2ws.append(c2w) + tgt_cam_params = [Camera(cam_param) for cam_param in c2ws] + relative_poses = [] + for i in range(len(tgt_cam_params)): + relative_pose = self.get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]]) + relative_poses.append(torch.as_tensor(relative_pose)[:,:3,:][1]) + pose_embedding = torch.stack(relative_poses, dim=0) # [target_frames, 3, 4] + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') # [target_frames, 12] + data['camera'] = pose_embedding.to(torch.bfloat16) + return data + + def __len__(self): + return len(self.path) + +def parse_args(): + parser = argparse.ArgumentParser(description="ReCamMaster Inference") + parser.add_argument( + "--dataset_path", + type=str, + default="./example_test_data", + help="The path of the Dataset.", + ) + parser.add_argument( + "--ckpt_path", + type=str, + default="/share_zhuyixuan05/zhuyixuan05/recam_future_checkpoint/step1000.ckpt", + help="Path to save the model.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="./results", + help="Path to save the results.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--cam_type", + type=str, + default=1, + ) + parser.add_argument( + "--cfg_scale", + type=float, + default=5.0, + ) + # 🔧 新增:condition和target帧数参数 + parser.add_argument( + "--condition_frames", + type=int, + default=15, + help="Number of condition frames", + ) + parser.add_argument( + "--target_frames", + type=int, + default=15, + help="Number of target frames to generate", + ) + args = parser.parse_args() + return args + +if __name__ == '__main__': + args = parse_args() + + # 1. 
Load Wan2.1 pre-trained models + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ]) + pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda") + + # 2. Initialize additional modules introduced in ReCamMaster + dim=pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in pipe.dit.blocks: + block.cam_encoder = nn.Linear(12, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + # 3. Load ReCamMaster checkpoint + state_dict = torch.load(args.ckpt_path, map_location="cpu") + pipe.dit.load_state_dict(state_dict, strict=True) + pipe.to("cuda") + pipe.to(dtype=torch.bfloat16) + + output_dir = os.path.join(args.output_dir, f"cam_type{args.cam_type}") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # 4. Prepare test data (source video, target camera, target trajectory) + dataset = TextVideoCameraDataset( + args.dataset_path, + os.path.join(args.dataset_path, "metadata.csv"), + args, + condition_frames=args.condition_frames, # 🔧 传递参数 + target_frames=args.target_frames, # 🔧 传递参数 + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=False, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + + # 5. 
Inference + for batch_idx, batch in enumerate(dataloader): + target_text = batch["text"] + source_video = batch["video"] + target_camera = batch["camera"] + + video = pipe( + prompt=target_text, + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + source_video=source_video, + target_camera=target_camera, + cfg_scale=args.cfg_scale, + num_inference_steps=50, + seed=0, + tiled=True, + condition_frames=args.condition_frames, + target_frames=args.target_frames, + ) + save_video(video, os.path.join(output_dir, f"video{batch_idx}.mp4"), fps=30, quality=5) \ No newline at end of file diff --git a/scripts/infer_rlbench.py b/scripts/infer_rlbench.py new file mode 100644 index 0000000000000000000000000000000000000000..72f918a7f6cfebca3b6b14f8a84a771a7a2e3d38 --- /dev/null +++ b/scripts/infer_rlbench.py @@ -0,0 +1,447 @@ +import os +import torch +import numpy as np +from PIL import Image +import imageio +import json +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import argparse +from torchvision.transforms import v2 +from einops import rearrange +import torch.nn as nn + + +def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10): + """ + 从pth文件加载预编码的视频数据 + Args: + pth_path: pth文件路径 + start_frame: 起始帧索引(基于压缩后的latent帧数) + num_frames: 需要的帧数(基于压缩后的latent帧数) + Returns: + condition_latents: [C, T, H, W] 格式的latent tensor + """ + print(f"Loading encoded video from {pth_path}") + + # 加载编码数据 + encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu") + + # 获取latent数据 + full_latents = encoded_data['latents'] # [C, T, H, W] + + print(f"Full latents shape: {full_latents.shape}") + print(f"Extracting frames {start_frame} to {start_frame + num_frames}") + + # 检查帧数是否足够 + if start_frame + num_frames > full_latents.shape[1]: + raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}") 
+ + # 提取指定帧数 + condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :] + + print(f"Extracted condition latents shape: {condition_latents.shape}") + + return condition_latents, encoded_data + + +def compute_relative_pose(pose_a, pose_b, use_torch=False): + """ + 计算相机B相对于相机A的相对位姿矩阵 + """ + assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}" + assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}" + + if use_torch: + if not isinstance(pose_a, torch.Tensor): + pose_a = torch.from_numpy(pose_a).float() + if not isinstance(pose_b, torch.Tensor): + pose_b = torch.from_numpy(pose_b).float() + + pose_a_inv = torch.inverse(pose_a) + relative_pose = torch.matmul(pose_b, pose_a_inv) + else: + if not isinstance(pose_a, np.ndarray): + pose_a = np.array(pose_a, dtype=np.float32) + if not isinstance(pose_b, np.ndarray): + pose_b = np.array(pose_b, dtype=np.float32) + + pose_a_inv = np.linalg.inv(pose_a) + relative_pose = np.matmul(pose_b, pose_a_inv) + + return relative_pose + + +def generate_camera_poses_from_data(cam_data, start_frame, condition_frames, target_frames): + """ + 从实际相机数据生成pose embeddings + Args: + cam_data: 相机外参数据 + start_frame: 起始帧(原始帧索引) + condition_frames: 条件帧数(压缩后) + target_frames: 目标帧数(压缩后) + """ + time_compression_ratio = 4 + total_frames = condition_frames + target_frames + + # 获取相机外参序列 + cam_extrinsic = cam_data # [N, 4, 4] + + # 计算原始帧索引 + start_frame_original = start_frame * time_compression_ratio + end_frame_original = (start_frame + total_frames) * time_compression_ratio + + print(f"Using camera data from frame {start_frame_original} to {end_frame_original}") + + # 计算相对pose + relative_poses = [] + for i in range(total_frames): + frame_idx = start_frame_original + i * time_compression_ratio + next_frame_idx = frame_idx + time_compression_ratio + + + cam_prev = cam_extrinsic[frame_idx] + + + + relative_poses.append(torch.as_tensor(cam_prev)) # 取前3行 + + print(cam_prev) + # 组装pose embedding + 
pose_embedding = torch.stack(relative_poses, dim=0) + # print('pose_embedding init:',pose_embedding[0]) + print('pose_embedding:',pose_embedding) + # assert False + + # pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') # [frames, 12] + + # 添加mask信息 + mask = torch.zeros(total_frames, dtype=torch.float32) + mask[:condition_frames] = 1.0 # condition frames + mask = mask.view(-1, 1) + + # 组合pose和mask + camera_embedding = torch.cat([pose_embedding, mask], dim=1) # [frames, 13] + + print(f"Generated camera embedding shape: {camera_embedding.shape}") + + return camera_embedding.to(torch.bfloat16) + + +def generate_camera_poses(direction="forward", target_frames=10, condition_frames=20): + """ + 根据指定方向生成相机pose序列(合成数据) + """ + time_compression_ratio = 4 + total_frames = condition_frames + target_frames + + poses = [] + + for i in range(total_frames): + t = i / max(1, total_frames - 1) # 0 to 1 + + # 创建变换矩阵 + pose = np.eye(4, dtype=np.float32) + + if direction == "forward": + # 前进:沿z轴负方向移动 + pose[2, 3] = -t * 0.04 + print('forward!') + + elif direction == "backward": + # 后退:沿z轴正方向移动 + pose[2, 3] = t * 2.0 + + elif direction == "left_turn": + # 左转:前进 + 绕y轴旋转 + pose[2, 3] = -t * 0.03 # 前进 + pose[0, 3] = t * 0.02 # 左移 + # 添加旋转 + yaw = t * 1 + pose[0, 0] = np.cos(yaw) + pose[0, 2] = np.sin(yaw) + pose[2, 0] = -np.sin(yaw) + pose[2, 2] = np.cos(yaw) + + elif direction == "right_turn": + # 右转:前进 + 绕y轴反向旋转 + pose[2, 3] = -t * 0.03 # 前进 + pose[0, 3] = -t * 0.02 # 右移 + # 添加旋转 + yaw = - t * 1 + pose[0, 0] = np.cos(yaw) + pose[0, 2] = np.sin(yaw) + pose[2, 0] = -np.sin(yaw) + pose[2, 2] = np.cos(yaw) + + poses.append(pose) + + # 计算相对pose + relative_poses = [] + for i in range(len(poses) - 1): + relative_pose = compute_relative_pose(poses[i], poses[i + 1]) + relative_poses.append(torch.as_tensor(relative_pose[:3, :])) # 取前3行 + + # 为了匹配模型输入,需要确保帧数正确 + if len(relative_poses) < total_frames: + # 补充最后一帧 + relative_poses.append(relative_poses[-1]) + + pose_embedding = 
torch.stack(relative_poses[:total_frames], dim=0) + + print('pose_embedding init:',pose_embedding[0]) + + print('pose_embedding:',pose_embedding[-5:]) + + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') # [frames, 12] + + # 添加mask信息 + mask = torch.zeros(total_frames, dtype=torch.float32) + mask[:condition_frames] = 1.0 # condition frames + mask = mask.view(-1, 1) + + # 组合pose和mask + camera_embedding = torch.cat([pose_embedding, mask], dim=1) # [frames, 13] + + print(f"Generated {direction} movement poses:") + print(f" Total frames: {total_frames}") + print(f" Camera embedding shape: {camera_embedding.shape}") + + return camera_embedding.to(torch.bfloat16) + + +def inference_sekai_video_from_pth( + condition_pth_path, + dit_path, + output_path="sekai/infer_results/output_sekai.mp4", + start_frame=0, + condition_frames=10, # 压缩后的帧数 + target_frames=2, # 压缩后的帧数 + device="cuda", + prompt="a robotic arm executing precise manipulation tasks on a clean, organized desk", + direction="forward", + use_real_poses=True +): + """ + 从pth文件进行Sekai视频推理 + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + print(f"Setting up models for {direction} movement...") + + # 1. 
Load models + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ]) + pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda") + + # Add camera components to DiT + dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in pipe.dit.blocks: + block.cam_encoder = nn.Linear(30, dim) # 13维embedding (12D pose + 1D mask) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + # Load trained DiT weights + dit_state_dict = torch.load(dit_path, map_location="cpu") + pipe.dit.load_state_dict(dit_state_dict, strict=True) + pipe = pipe.to(device) + pipe.scheduler.set_timesteps(50) + + print("Loading condition video from pth...") + + # Load condition video from pth + condition_latents, encoded_data = load_encoded_video_from_pth( + condition_pth_path, + start_frame=start_frame, + num_frames=condition_frames + ) + + condition_latents = condition_latents.unsqueeze(0).to(device, dtype=pipe.torch_dtype) + + print("Processing poses...") + + # 生成相机pose embedding + if use_real_poses and 'cam_emb' in encoded_data: + print("Using real camera poses from data") + camera_embedding = generate_camera_poses_from_data( + encoded_data['cam_emb'], + start_frame=start_frame, + condition_frames=condition_frames, + target_frames=target_frames + ) + else: + print(f"Using synthetic {direction} poses") + camera_embedding = generate_camera_poses( + direction=direction, + target_frames=target_frames, + condition_frames=condition_frames + ) + + + + camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16) + + print(f"Camera embedding 
shape: {camera_embedding.shape}") + + print("Encoding prompt...") + + # Encode text prompt + prompt_emb = pipe.encode_prompt(prompt) + + print("Generating video...") + + # Generate target latents + batch_size = 1 + channels = condition_latents.shape[1] + latent_height = condition_latents.shape[3] + latent_width = condition_latents.shape[4] + + # 空间裁剪以节省内存(如果需要) + target_height, target_width = 64, 64 + + if latent_height > target_height or latent_width > target_width: + # 中心裁剪 + h_start = (latent_height - target_height) // 2 + w_start = (latent_width - target_width) // 2 + condition_latents = condition_latents[:, :, :, + h_start:h_start+target_height, + w_start:w_start+target_width] + latent_height = target_height + latent_width = target_width + + # Initialize target latents with noise + target_latents = torch.randn( + batch_size, channels, target_frames, latent_height, latent_width, + device=device, dtype=pipe.torch_dtype + ) + + print(f"Condition latents shape: {condition_latents.shape}") + print(f"Target latents shape: {target_latents.shape}") + print(f"Camera embedding shape: {camera_embedding.shape}") + + # Combine condition and target latents + combined_latents = torch.cat([condition_latents, target_latents], dim=2) + print(f"Combined latents shape: {combined_latents.shape}") + + # Prepare extra inputs + extra_input = pipe.prepare_extra_input(combined_latents) + + # Denoising loop + timesteps = pipe.scheduler.timesteps + + for i, timestep in enumerate(timesteps): + print(f"Denoising step {i+1}/{len(timesteps)}") + + # Prepare timestep + timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype) + + # Predict noise + with torch.no_grad(): + noise_pred = pipe.dit( + combined_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + **prompt_emb, + **extra_input + ) + + # Update only target part + target_noise_pred = noise_pred[:, :, condition_frames:, :, :] + target_latents = pipe.scheduler.step(target_noise_pred, timestep, 
target_latents) + + # Update combined latents + combined_latents[:, :, condition_frames:, :, :] = target_latents + + print("Decoding video...") + + # Decode final video + final_video = torch.cat([condition_latents, target_latents], dim=2) + decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + + # Save video + print(f"Saving video to {output_path}") + + # Convert to numpy and save + video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy() + video_np = (video_np * 0.5 + 0.5).clip(0, 1) # Denormalize + video_np = (video_np * 255).astype(np.uint8) + + with imageio.get_writer(output_path, fps=20) as writer: + for frame in video_np: + writer.append_data(frame) + + print(f"Video generation completed! Saved to {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Sekai Video Generation Inference from PTH") + parser.add_argument("--condition_pth", type=str, + default="/share_zhuyixuan05/zhuyixuan05/rlbench/OpenBox_demo_49/encoded_video.pth") + parser.add_argument("--start_frame", type=int, default=0, + help="Starting frame index (compressed latent frames)") + parser.add_argument("--condition_frames", type=int, default=8, + help="Number of condition frames (compressed latent frames)") + parser.add_argument("--target_frames", type=int, default=8, + help="Number of target frames to generate (compressed latent frames)") + parser.add_argument("--direction", type=str, default="left_turn", + choices=["forward", "backward", "left_turn", "right_turn"], + help="Direction of camera movement (if not using real poses)") + parser.add_argument("--use_real_poses", default=False, + help="Use real camera poses from data") + parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/RLBench-train/step2000_dynamic.ckpt", + help="Path to trained DiT checkpoint") + parser.add_argument("--output_path", type=str, 
default='/home/zhuyixuan05/ReCamMaster/rlbench/infer_results/output_rl_2.mp4', + help="Output video path") + parser.add_argument("--prompt", type=str, + default="a robotic arm executing precise manipulation tasks on a clean, organized desk", + help="Text prompt for generation") + parser.add_argument("--device", type=str, default="cuda", + help="Device to run inference on") + + args = parser.parse_args() + + # 生成输出路径 + if args.output_path is None: + pth_filename = os.path.basename(args.condition_pth) + name_parts = os.path.splitext(pth_filename) + output_dir = "rlbench/infer_results" + os.makedirs(output_dir, exist_ok=True) + + if args.use_real_poses: + output_filename = f"{name_parts[0]}_real_poses_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4" + else: + output_filename = f"{name_parts[0]}_{args.direction}_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4" + + output_path = os.path.join(output_dir, output_filename) + else: + output_path = args.output_path + + print(f"Input pth: {args.condition_pth}") + print(f"Start frame: {args.start_frame} (compressed)") + print(f"Condition frames: {args.condition_frames} (compressed, original: {args.condition_frames * 4})") + print(f"Target frames: {args.target_frames} (compressed, original: {args.target_frames * 4})") + print(f"Use real poses: {args.use_real_poses}") + print(f"Output video will be saved to: {output_path}") + + inference_sekai_video_from_pth( + condition_pth_path=args.condition_pth, + dit_path=args.dit_path, + output_path=output_path, + start_frame=args.start_frame, + condition_frames=args.condition_frames, + target_frames=args.target_frames, + device=args.device, + prompt=args.prompt, + direction=args.direction, + use_real_poses=args.use_real_poses + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/infer_sekai.py b/scripts/infer_sekai.py new file mode 100644 index 
0000000000000000000000000000000000000000..6b06d1da1b01559cd117f10ae1e745535678e952 --- /dev/null +++ b/scripts/infer_sekai.py @@ -0,0 +1,497 @@ +import os +import torch +import numpy as np +from PIL import Image +import imageio +import json +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import argparse +from torchvision.transforms import v2 +from einops import rearrange +import torch.nn as nn + + +def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10): + """ + 从pth文件加载预编码的视频数据 + Args: + pth_path: pth文件路径 + start_frame: 起始帧索引(基于压缩后的latent帧数) + num_frames: 需要的帧数(基于压缩后的latent帧数) + Returns: + condition_latents: [C, T, H, W] 格式的latent tensor + """ + print(f"Loading encoded video from {pth_path}") + + # 加载编码数据 + encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu") + + # 获取latent数据 + full_latents = encoded_data['latents'] # [C, T, H, W] + + print(f"Full latents shape: {full_latents.shape}") + print(f"Extracting frames {start_frame} to {start_frame + num_frames}") + + # 检查帧数是否足够 + if start_frame + num_frames > full_latents.shape[1]: + raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}") + + # 提取指定帧数 + condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :] + + print(f"Extracted condition latents shape: {condition_latents.shape}") + + return condition_latents, encoded_data + + +def compute_relative_pose(pose_a, pose_b, use_torch=False): + """ + 计算相机B相对于相机A的相对位姿矩阵 + """ + assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}" + assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}" + + if use_torch: + if not isinstance(pose_a, torch.Tensor): + pose_a = torch.from_numpy(pose_a).float() + if not isinstance(pose_b, torch.Tensor): + pose_b = torch.from_numpy(pose_b).float() + + pose_a_inv = torch.inverse(pose_a) + relative_pose = torch.matmul(pose_b, pose_a_inv) + else: + if not isinstance(pose_a, np.ndarray): + 
pose_a = np.array(pose_a, dtype=np.float32) + if not isinstance(pose_b, np.ndarray): + pose_b = np.array(pose_b, dtype=np.float32) + + pose_a_inv = np.linalg.inv(pose_a) + relative_pose = np.matmul(pose_b, pose_a_inv) + + return relative_pose + + +def generate_camera_poses_from_data(cam_data, start_frame, condition_frames, target_frames): + """ + 从实际相机数据生成pose embeddings + Args: + cam_data: 相机外参数据 + start_frame: 起始帧(原始帧索引) + condition_frames: 条件帧数(压缩后) + target_frames: 目标帧数(压缩后) + """ + time_compression_ratio = 4 + total_frames = condition_frames + target_frames + + # 获取相机外参序列 + cam_extrinsic = cam_data['extrinsic'] # [N, 4, 4] + + # 计算原始帧索引 + start_frame_original = start_frame * time_compression_ratio + end_frame_original = (start_frame + total_frames) * time_compression_ratio + + print(f"Using camera data from frame {start_frame_original} to {end_frame_original}") + + # 计算相对pose + relative_poses = [] + for i in range(total_frames): + frame_idx = start_frame_original + i * time_compression_ratio + next_frame_idx = frame_idx + time_compression_ratio + + if next_frame_idx >= len(cam_extrinsic): + print('out of temporal range!!!') + # 如果超出范围,使用最后一个可用的pose + relative_poses.append(relative_poses[-1] if relative_poses else torch.zeros(3, 4)) + else: + cam_prev = cam_extrinsic[frame_idx] + cam_next = cam_extrinsic[next_frame_idx] + + relative_pose = compute_relative_pose(cam_prev, cam_next) + relative_poses.append(torch.as_tensor(relative_pose[:3, :])) # 取前3行 + + print(cam_prev) + # 组装pose embedding + pose_embedding = torch.stack(relative_poses, dim=0) + # print('pose_embedding init:',pose_embedding[0]) + print('pose_embedding:',pose_embedding) + assert False + + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') # [frames, 12] + + # 添加mask信息 + mask = torch.zeros(total_frames, dtype=torch.float32) + mask[:condition_frames] = 1.0 # condition frames + mask = mask.view(-1, 1) + + # 组合pose和mask + camera_embedding = torch.cat([pose_embedding, mask], dim=1) # 
[frames, 13] + + print(f"Generated camera embedding shape: {camera_embedding.shape}") + + return camera_embedding.to(torch.bfloat16) + + +def generate_camera_poses(direction="forward", target_frames=10, condition_frames=20): + """ + 根据指定方向生成相机pose序列(合成数据) + """ + time_compression_ratio = 4 + total_frames = condition_frames + target_frames + + poses = [] + + for i in range(total_frames): + t = i / max(1, total_frames - 1) # 0 to 1 + + # 创建变换矩阵 + pose = np.eye(4, dtype=np.float32) + + if direction == "forward": + # 前进:沿z轴负方向移动 + pose[2, 3] = -t * 0.04 + print('forward!') + + elif direction == "backward": + # 后退:沿z轴正方向移动 + pose[2, 3] = t * 2.0 + + elif direction == "left_turn": + # 左转:前进 + 绕y轴旋转 + pose[2, 3] = -t * 0.03 # 前进 + pose[0, 3] = t * 0.02 # 左移 + # 添加旋转 + yaw = t * 1 + pose[0, 0] = np.cos(yaw) + pose[0, 2] = np.sin(yaw) + pose[2, 0] = -np.sin(yaw) + pose[2, 2] = np.cos(yaw) + + elif direction == "right_turn": + # 右转:前进 + 绕y轴反向旋转 + pose[2, 3] = -t * 0.03 # 前进 + pose[0, 3] = -t * 0.02 # 右移 + # 添加旋转 + yaw = - t * 1 + pose[0, 0] = np.cos(yaw) + pose[0, 2] = np.sin(yaw) + pose[2, 0] = -np.sin(yaw) + pose[2, 2] = np.cos(yaw) + + poses.append(pose) + + # 计算相对pose + relative_poses = [] + for i in range(len(poses) - 1): + relative_pose = compute_relative_pose(poses[i], poses[i + 1]) + relative_poses.append(torch.as_tensor(relative_pose[:3, :])) # 取前3行 + + # 为了匹配模型输入,需要确保帧数正确 + if len(relative_poses) < total_frames: + # 补充最后一帧 + relative_poses.append(relative_poses[-1]) + + pose_embedding = torch.stack(relative_poses[:total_frames], dim=0) + + print('pose_embedding init:',pose_embedding[0]) + + print('pose_embedding:',pose_embedding[-5:]) + + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') # [frames, 12] + + # 添加mask信息 + mask = torch.zeros(total_frames, dtype=torch.float32) + mask[:condition_frames] = 1.0 # condition frames + mask = mask.view(-1, 1) + + # 组合pose和mask + camera_embedding = torch.cat([pose_embedding, mask], dim=1) # [frames, 13] + + 
print(f"Generated {direction} movement poses:") + print(f" Total frames: {total_frames}") + print(f" Camera embedding shape: {camera_embedding.shape}") + + return camera_embedding.to(torch.bfloat16) + + +def inference_sekai_video_from_pth( + condition_pth_path, + dit_path, + output_path="sekai/infer_results/output_sekai.mp4", + start_frame=0, + condition_frames=10, # 压缩后的帧数 + target_frames=2, # 压缩后的帧数 + device="cuda", + prompt="A video of a scene shot using a pedestrian's front camera while walking", + direction="forward", + use_real_poses=True +): + """ + 从pth文件进行Sekai视频推理 + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + print(f"Setting up models for {direction} movement...") + + # 1. Load models + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ]) + pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda") + + # Add camera components to DiT + dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in pipe.dit.blocks: + block.cam_encoder = nn.Linear(13, dim) # 13维embedding (12D pose + 1D mask) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + # Load trained DiT weights + dit_state_dict = torch.load(dit_path, map_location="cpu") + pipe.dit.load_state_dict(dit_state_dict, strict=True) + pipe = pipe.to(device) + pipe.scheduler.set_timesteps(50) + + print("Loading condition video from pth...") + + # Load condition video from pth + condition_latents, encoded_data = load_encoded_video_from_pth( + condition_pth_path, + start_frame=start_frame, + num_frames=condition_frames + ) + + condition_latents 
= condition_latents.unsqueeze(0).to(device, dtype=pipe.torch_dtype) + + print("Processing poses...") + + # 生成相机pose embedding + if use_real_poses and 'cam_emb' in encoded_data: + print("Using real camera poses from data") + camera_embedding = generate_camera_poses_from_data( + encoded_data['cam_emb'], + start_frame=start_frame, + condition_frames=condition_frames, + target_frames=target_frames + ) + else: + print(f"Using synthetic {direction} poses") + camera_embedding = generate_camera_poses( + direction=direction, + target_frames=target_frames, + condition_frames=condition_frames + ) + + # camera_embedding = torch.tensor([ + # [ 9.9992e-01, 5.7823e-04, -1.2807e-02, -6.4978e-03, -6.1466e-04, + # 1.0000e+00, -2.8406e-03, -7.1422e-04, 1.2806e-02, 2.8482e-03, + # 9.9991e-01, -1.4152e-02, 1.0000e+00], + # [ 9.9993e-01, 5.0678e-04, -1.1601e-02, -5.7938e-03, -5.3597e-04, + # 1.0000e+00, -2.5129e-03, -5.6941e-04, 1.1600e-02, 2.5189e-03, + # 9.9993e-01, -1.4078e-02, 1.0000e+00], + # [ 9.9992e-01, 4.4420e-04, -1.2374e-02, -6.2723e-03, -4.8356e-04, + # 9.9999e-01, -3.1780e-03, -1.0313e-03, 1.2372e-02, 3.1837e-03, + # 9.9992e-01, -1.4170e-02, 1.0000e+00], + # [ 9.9997e-01, 2.6684e-04, -7.1423e-03, -2.9546e-03, -2.7965e-04, + # 1.0000e+00, -1.7922e-03, -2.0437e-04, 7.1418e-03, 1.7942e-03, + # 9.9997e-01, -1.3811e-02, 1.0000e+00], + # [ 9.9999e-01, 1.5524e-04, -4.1128e-03, -9.7896e-04, -1.5948e-04, + # 1.0000e+00, -1.0322e-03, 2.5742e-04, 4.1126e-03, 1.0328e-03, + # 9.9999e-01, -1.3608e-02, 1.0000e+00], + # [ 1.0000e+00, 8.9919e-05, -2.3684e-03, 1.8947e-04, -9.1325e-05, + # 1.0000e+00, -5.9445e-04, 5.2862e-04, 2.3683e-03, 5.9466e-04, + # 1.0000e+00, -1.3490e-02, 1.0000e+00], + # [ 1.0000e+00, 5.1932e-05, -1.3635e-03, 8.8221e-04, -5.2401e-05, + # 1.0000e+00, -3.4229e-04, 6.8774e-04, 1.3635e-03, 3.4236e-04, + # 1.0000e+00, -1.3419e-02, 1.0000e+00], + # [ 1.0000e+00, 2.9971e-05, -7.8533e-04, 1.2923e-03, -3.0129e-05, + # 1.0000e+00, -1.9714e-04, 7.8124e-04, 7.8534e-04, 1.9716e-04, 
+ # 1.0000e+00, -1.3378e-02, 1.0000e+00], + # [ 1.0000e+00, 1.7271e-05, -4.5211e-04, 1.5351e-03, -1.7318e-05, + # 1.0000e+00, -1.1352e-04, 8.3586e-04, 4.5211e-04, 1.1353e-04, + # 1.0000e+00, -1.3353e-02, 1.0000e+00], + # [ 1.0000e+00, 9.9305e-06, -2.5968e-04, 1.6798e-03, -9.9495e-06, + # 1.0000e+00, -6.5163e-05, 8.6798e-04, 2.5970e-04, 6.5163e-05, + # 1.0000e+00, -1.3338e-02, 1.0000e+00], + # [ 1.0000e+00, 1.4484e-05, -3.7806e-04, 1.5971e-03, -1.4521e-05, + # 1.0000e+00, -9.4604e-05, 8.4546e-04, 3.7804e-04, 9.4615e-05, + # 1.0000e+00, -1.3347e-02, 0.0000e+00], + # [ 1.0000e+00, 6.5319e-05, -9.4321e-04, 1.1732e-03, -6.5316e-05, + # 1.0000e+00, 5.4177e-06, 9.2146e-04, 9.4322e-04, -5.3641e-06, + # 1.0000e+00, -1.3372e-02, 0.0000e+00], + # [ 9.9999e-01, 2.5994e-04, -3.9389e-03, -1.0991e-03, -2.6020e-04, + # 1.0000e+00, -6.6082e-05, 8.7861e-04, 3.9388e-03, 6.7103e-05, + # 9.9999e-01, -1.3561e-02, 0.0000e+00], + # [ 9.9998e-01, 2.7008e-04, -6.8774e-03, -3.3641e-03, -2.7882e-04, + # 1.0000e+00, -1.2689e-03, -5.0134e-05, 6.8771e-03, 1.2708e-03, + # 9.9998e-01, -1.3853e-02, 0.0000e+00], + # [ 9.9996e-01, 4.6250e-04, -8.4143e-03, -4.5899e-03, -4.6835e-04, + # 1.0000e+00, -6.9268e-04, 3.9740e-04, 8.4139e-03, 6.9660e-04, + # 9.9996e-01, -1.3917e-02, 0.0000e+00] + #], dtype=torch.bfloat16, device=device) + + camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16) + + print(f"Camera embedding shape: {camera_embedding.shape}") + + print("Encoding prompt...") + + # Encode text prompt + prompt_emb = pipe.encode_prompt(prompt) + + print("Generating video...") + + # Generate target latents + batch_size = 1 + channels = condition_latents.shape[1] + latent_height = condition_latents.shape[3] + latent_width = condition_latents.shape[4] + + # 空间裁剪以节省内存(如果需要) + target_height, target_width = 60, 104 + + if latent_height > target_height or latent_width > target_width: + # 中心裁剪 + h_start = (latent_height - target_height) // 2 + w_start = (latent_width - 
target_width) // 2 + condition_latents = condition_latents[:, :, :, + h_start:h_start+target_height, + w_start:w_start+target_width] + latent_height = target_height + latent_width = target_width + + # Initialize target latents with noise + target_latents = torch.randn( + batch_size, channels, target_frames, latent_height, latent_width, + device=device, dtype=pipe.torch_dtype + ) + + print(f"Condition latents shape: {condition_latents.shape}") + print(f"Target latents shape: {target_latents.shape}") + print(f"Camera embedding shape: {camera_embedding.shape}") + + # Combine condition and target latents + combined_latents = torch.cat([condition_latents, target_latents], dim=2) + print(f"Combined latents shape: {combined_latents.shape}") + + # Prepare extra inputs + extra_input = pipe.prepare_extra_input(combined_latents) + + # Denoising loop + timesteps = pipe.scheduler.timesteps + + for i, timestep in enumerate(timesteps): + print(f"Denoising step {i+1}/{len(timesteps)}") + + # Prepare timestep + timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype) + + # Predict noise + with torch.no_grad(): + noise_pred = pipe.dit( + combined_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + **prompt_emb, + **extra_input + ) + + # Update only target part + target_noise_pred = noise_pred[:, :, condition_frames:, :, :] + target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents) + + # Update combined latents + combined_latents[:, :, condition_frames:, :, :] = target_latents + + print("Decoding video...") + + # Decode final video + final_video = torch.cat([condition_latents, target_latents], dim=2) + decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + + # Save video + print(f"Saving video to {output_path}") + + # Convert to numpy and save + video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy() + video_np = (video_np * 0.5 + 0.5).clip(0, 1) # 
Denormalize + video_np = (video_np * 255).astype(np.uint8) + + with imageio.get_writer(output_path, fps=20) as writer: + for frame in video_np: + writer.append_data(frame) + + print(f"Video generation completed! Saved to {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Sekai Video Generation Inference from PTH") + parser.add_argument("--condition_pth", type=str, + default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth") + parser.add_argument("--start_frame", type=int, default=0, + help="Starting frame index (compressed latent frames)") + parser.add_argument("--condition_frames", type=int, default=8, + help="Number of condition frames (compressed latent frames)") + parser.add_argument("--target_frames", type=int, default=8, + help="Number of target frames to generate (compressed latent frames)") + parser.add_argument("--direction", type=str, default="left_turn", + choices=["forward", "backward", "left_turn", "right_turn"], + help="Direction of camera movement (if not using real poses)") + parser.add_argument("--use_real_poses", default=False, + help="Use real camera poses from data") + parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/sekai_walking_noise/step14000_dynamic.ckpt", + help="Path to trained DiT checkpoint") + parser.add_argument("--output_path", type=str, default='/home/zhuyixuan05/ReCamMaster/sekai/infer_noise_results/output_sekai_right_turn.mp4', + help="Output video path") + parser.add_argument("--prompt", type=str, + default="A drone flying scene in a game world", + help="Text prompt for generation") + parser.add_argument("--device", type=str, default="cuda", + help="Device to run inference on") + + args = parser.parse_args() + + # 生成输出路径 + if args.output_path is None: + pth_filename = os.path.basename(args.condition_pth) + name_parts = os.path.splitext(pth_filename) + output_dir = "sekai/infer_results" + os.makedirs(output_dir, 
exist_ok=True) + + if args.use_real_poses: + output_filename = f"{name_parts[0]}_real_poses_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4" + else: + output_filename = f"{name_parts[0]}_{args.direction}_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4" + + output_path = os.path.join(output_dir, output_filename) + else: + output_path = args.output_path + + print(f"Input pth: {args.condition_pth}") + print(f"Start frame: {args.start_frame} (compressed)") + print(f"Condition frames: {args.condition_frames} (compressed, original: {args.condition_frames * 4})") + print(f"Target frames: {args.target_frames} (compressed, original: {args.target_frames * 4})") + print(f"Use real poses: {args.use_real_poses}") + print(f"Output video will be saved to: {output_path}") + + inference_sekai_video_from_pth( + condition_pth_path=args.condition_pth, + dit_path=args.dit_path, + output_path=output_path, + start_frame=args.start_frame, + condition_frames=args.condition_frames, + target_frames=args.target_frames, + device=args.device, + prompt=args.prompt, + direction=args.direction, + use_real_poses=args.use_real_poses + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/infer_sekai_framepack.py b/scripts/infer_sekai_framepack.py new file mode 100644 index 0000000000000000000000000000000000000000..ec3ff650fa81c8532e1362cd00f17132a68635e1 --- /dev/null +++ b/scripts/infer_sekai_framepack.py @@ -0,0 +1,675 @@ +import os +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +import imageio +import json +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import argparse +from torchvision.transforms import v2 +from einops import rearrange +import copy + + +def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10): + """从pth文件加载预编码的视频数据""" + print(f"Loading encoded video from {pth_path}") + + encoded_data = torch.load(pth_path, weights_only=False, 
map_location="cpu") + full_latents = encoded_data['latents'] # [C, T, H, W] + + print(f"Full latents shape: {full_latents.shape}") + print(f"Extracting frames {start_frame} to {start_frame + num_frames}") + + if start_frame + num_frames > full_latents.shape[1]: + raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}") + + condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :] + print(f"Extracted condition latents shape: {condition_latents.shape}") + + return condition_latents, encoded_data + + +def compute_relative_pose(pose_a, pose_b, use_torch=False): + """计算相机B相对于相机A的相对位姿矩阵""" + assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}" + assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}" + + if use_torch: + if not isinstance(pose_a, torch.Tensor): + pose_a = torch.from_numpy(pose_a).float() + if not isinstance(pose_b, torch.Tensor): + pose_b = torch.from_numpy(pose_b).float() + + pose_a_inv = torch.inverse(pose_a) + relative_pose = torch.matmul(pose_b, pose_a_inv) + else: + if not isinstance(pose_a, np.ndarray): + pose_a = np.array(pose_a, dtype=np.float32) + if not isinstance(pose_b, np.ndarray): + pose_b = np.array(pose_b, dtype=np.float32) + + pose_a_inv = np.linalg.inv(pose_a) + relative_pose = np.matmul(pose_b, pose_a_inv) + + return relative_pose + +def replace_dit_model_in_manager(): + """替换DiT模型类为FramePack版本""" + from diffsynth.models.wan_video_dit_recam_future import WanModelFuture + from diffsynth.configs.model_config import model_loader_configs + + for i, config in enumerate(model_loader_configs): + keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config + + if 'wan_video_dit' in model_names: + new_model_names = [] + new_model_classes = [] + + for name, cls in zip(model_names, model_classes): + if name == 'wan_video_dit': + new_model_names.append(name) + new_model_classes.append(WanModelFuture) + print(f"✅ 替换了模型类: 
{name} -> WanModelFuture") + else: + new_model_names.append(name) + new_model_classes.append(cls) + + model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) + + +def add_framepack_components(dit_model): + """添加FramePack相关组件""" + if not hasattr(dit_model, 'clean_x_embedder'): + inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + if scale == "1x": + x = x.to(self.proj.weight.dtype) + return self.proj(x) + elif scale == "2x": + x = x.to(self.proj_2x.weight.dtype) + return self.proj_2x(x) + elif scale == "4x": + x = x.to(self.proj_4x.weight.dtype) + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + dit_model.clean_x_embedder = CleanXEmbedder(inner_dim) + model_dtype = next(dit_model.parameters()).dtype + dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype) + print("✅ 添加了FramePack的clean_x_embedder组件") + +def generate_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True): + """🔧 为滑动窗口生成camera embeddings - 修正长度计算,确保包含start_latent帧""" + time_compression_ratio = 4 + + # 🔧 计算FramePack实际需要的camera帧数 + # FramePack结构: 1(start) + 16(4x) + 2(2x) + 1(1x) + target_frames + framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames + + if use_real_poses and cam_data is not None and 'extrinsic' in cam_data: + print("🔧 使用真实camera数据") + cam_extrinsic = cam_data['extrinsic'] + + # 🔧 确保生成足够长的camera序列 + # 需要考虑:当前历史位置 + FramePack所需的完整结构 + max_needed_frames = max( + start_frame + current_history_length + new_frames, # 基础需求 + 
framepack_needed_frames, # FramePack结构需求 + 30 # 最小保证长度 + ) + + print(f"🔧 计算camera序列长度:") + print(f" - 基础需求: {start_frame + current_history_length + new_frames}") + print(f" - FramePack需求: {framepack_needed_frames}") + print(f" - 最终生成: {max_needed_frames}") + + relative_poses = [] + for i in range(max_needed_frames): + # 计算当前帧在原始序列中的位置 + frame_idx = i * time_compression_ratio + next_frame_idx = frame_idx + time_compression_ratio + + if next_frame_idx < len(cam_extrinsic): + cam_prev = cam_extrinsic[frame_idx] + cam_next = cam_extrinsic[next_frame_idx] + relative_pose = compute_relative_pose(cam_prev, cam_next) + relative_poses.append(torch.as_tensor(relative_pose[:3, :])) + else: + # 超出范围,使用零运动 + print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动") + relative_poses.append(torch.zeros(3, 4)) + + pose_embedding = torch.stack(relative_poses, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + + # 🔧 创建对应长度的mask序列 + mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) + # 从start_frame到current_history_length标记为condition + condition_end = min(start_frame + current_history_length, max_needed_frames) + mask[start_frame:condition_end] = 1.0 + + camera_embedding = torch.cat([pose_embedding, mask], dim=1) + print(f"🔧 真实camera embedding shape: {camera_embedding.shape} (总长度:{max_needed_frames})") + return camera_embedding.to(torch.bfloat16) + + else: + print("🔧 使用合成camera数据") + + # 🔧 确保合成数据也有足够长度 + max_needed_frames = max( + start_frame + current_history_length + new_frames, + framepack_needed_frames, + 30 + ) + + print(f"🔧 生成合成camera帧数: {max_needed_frames}") + print(f" - FramePack需求: {framepack_needed_frames}") + + relative_poses = [] + for i in range(max_needed_frames): + # 🔧 持续左转运动模式 + # 每帧旋转一个固定角度,同时前进 + yaw_per_frame = -0.05 # 每帧左转(正角度表示左转) + forward_speed = 0.005 # 每帧前进距离 + + # 计算当前累积角度 + current_yaw = i * yaw_per_frame + + # 创建相对变换矩阵(从第i帧到第i+1帧的变换) + pose = np.eye(4, dtype=np.float32) + + # 旋转矩阵(绕Y轴左转) + cos_yaw = np.cos(yaw_per_frame) + sin_yaw = 
np.sin(yaw_per_frame) + + pose[0, 0] = cos_yaw + pose[0, 2] = sin_yaw + pose[2, 0] = -sin_yaw + pose[2, 2] = cos_yaw + + # 平移(在旋转后的局部坐标系中前进) + pose[2, 3] = -forward_speed # 局部Z轴负方向(前进) + + # 可选:添加轻微的向心运动,模拟圆形轨迹 + radius_drift = 0.002 # 向圆心的轻微漂移 + pose[0, 3] = radius_drift # 局部X轴负方向(向左) + + relative_pose = pose[:3, :] + relative_poses.append(torch.as_tensor(relative_pose)) + + pose_embedding = torch.stack(relative_poses, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + + # 创建对应长度的mask序列 + mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) + condition_end = min(start_frame + current_history_length, max_needed_frames) + mask[start_frame:condition_end] = 1.0 + + camera_embedding = torch.cat([pose_embedding, mask], dim=1) + print(f"🔧 合成camera embedding shape: {camera_embedding.shape} (总长度:{max_needed_frames})") + return camera_embedding.to(torch.bfloat16) + +def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49): + """🔧 FramePack滑动窗口机制 - 修正camera mask更新逻辑""" + # history_latents: [C, T, H, W] 当前的历史latents + C, T, H, W = history_latents.shape + + # 🔧 固定索引结构(这决定了需要的camera帧数) + total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate + indices = torch.arange(0, total_indices_length) + split_sizes = [1, 16, 2, 1, target_frames_to_generate] + clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \ + indices.split(split_sizes, dim=0) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0) + + # 🔧 检查camera长度是否足够 + if camera_embedding_full.shape[0] < total_indices_length: + shortage = total_indices_length - camera_embedding_full.shape[0] + padding = torch.zeros(shortage, camera_embedding_full.shape[1], + dtype=camera_embedding_full.dtype, device=camera_embedding_full.device) + camera_embedding_full = torch.cat([camera_embedding_full, 
padding], dim=0) + + # 🔧 从完整camera序列中选取对应部分 + combined_camera = camera_embedding_full[:total_indices_length, :].clone() # clone to avoid modifying original + + # 🔧 关键修正:根据当前history length重新设置mask + # combined_camera的结构对应: [1(start) + 16(4x) + 2(2x) + 1(1x) + target_frames] + # 前19帧对应clean latents,后面对应target + + # 清空所有mask,重新设置 + combined_camera[:, -1] = 0.0 # 先全部设为target (0) + + # 设置condition mask:前19帧根据实际历史长度决定 + if T > 0: + # 根据clean_latents的填充逻辑,确定哪些位置应该是condition + available_frames = min(T, 19) + start_pos = 19 - available_frames + + # 对应的camera位置也应该标记为condition + combined_camera[start_pos:19, -1] = 1.0 # 将有效的clean latents对应的camera标记为condition + + # target部分保持为0(已经在上面设置) + + print(f"🔧 Camera mask更新:") + print(f" - 历史帧数: {T}") + print(f" - 有效condition帧数: {available_frames if T > 0 else 0}") + print(f" - Condition mask (前19帧): {combined_camera[:19, -1].cpu().tolist()}") + print(f" - Target mask (后{target_frames_to_generate}帧): {combined_camera[19:, -1].cpu().tolist()}") + # 其余处理逻辑保持不变... 
+ clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device) + + if T > 0: + available_frames = min(T, 19) + start_pos = 19 - available_frames + clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :] + + clean_latents_4x = clean_latents_combined[:, 0:16, :, :] + clean_latents_2x = clean_latents_combined[:, 16:18, :, :] + clean_latents_1x = clean_latents_combined[:, 18:19, :, :] + + if T > 0: + start_latent = history_latents[:, 0:1, :, :] + else: + start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device) + + clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1) + + return { + 'latent_indices': latent_indices, + 'clean_latents': clean_latents, + 'clean_latents_2x': clean_latents_2x, + 'clean_latents_4x': clean_latents_4x, + 'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + 'camera_embedding': combined_camera, # 🔧 现在包含正确更新的mask + 'current_length': T, + 'next_length': T + target_frames_to_generate + } + +def inference_sekai_framepack_sliding_window( + condition_pth_path, + dit_path, + output_path="sekai/infer_results/output_sekai_framepack_sliding.mp4", + start_frame=0, + initial_condition_frames=8, + frames_per_generation=4, + total_frames_to_generate=32, + max_history_frames=49, + device="cuda", + prompt="A video of a scene shot using a pedestrian's front camera while walking", + use_real_poses=True, + synthetic_direction="forward", + # 🔧 新增CFG参数 + use_camera_cfg=True, + camera_guidance_scale=2.0, + text_guidance_scale=7.5 +): + """ + 🔧 FramePack滑动窗口视频生成 - 支持Camera CFG + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + print(f"🔧 FramePack滑动窗口生成开始...") + print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}") + print(f"Text guidance scale: {text_guidance_scale}") + 
print(f"初始条件帧: {initial_condition_frames}, 每次生成: {frames_per_generation}, 总生成: {total_frames_to_generate}") + print(f"使用真实姿态: {use_real_poses}") + if not use_real_poses: + print(f"合成camera方向: {synthetic_direction}") + + # 1-3. 模型初始化和组件添加(保持不变) + replace_dit_model_in_manager() + + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ]) + pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda") + + dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in pipe.dit.blocks: + block.cam_encoder = nn.Linear(13, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + add_framepack_components(pipe.dit) + + dit_state_dict = torch.load(dit_path, map_location="cpu") + pipe.dit.load_state_dict(dit_state_dict, strict=True) + pipe = pipe.to(device) + model_dtype = next(pipe.dit.parameters()).dtype + + if hasattr(pipe.dit, 'clean_x_embedder'): + pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype) + + pipe.scheduler.set_timesteps(50) + + # 4. 
加载初始条件 + print("Loading initial condition frames...") + initial_latents, encoded_data = load_encoded_video_from_pth( + condition_pth_path, + start_frame=start_frame, + num_frames=initial_condition_frames + ) + + # 空间裁剪 + target_height, target_width = 60, 104 + C, T, H, W = initial_latents.shape + + if H > target_height or W > target_width: + h_start = (H - target_height) // 2 + w_start = (W - target_width) // 2 + initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width] + H, W = target_height, target_width + + history_latents = initial_latents.to(device, dtype=model_dtype) + + print(f"初始history_latents shape: {history_latents.shape}") + + # 编码prompt - 支持CFG + if text_guidance_scale > 1.0: + # 编码positive prompt + prompt_emb_pos = pipe.encode_prompt(prompt) + # 编码negative prompt (空字符串) + prompt_emb_neg = pipe.encode_prompt("") + print(f"使用Text CFG,guidance scale: {text_guidance_scale}") + else: + prompt_emb_pos = pipe.encode_prompt(prompt) + prompt_emb_neg = None + print("不使用Text CFG") + + # 预生成完整的camera embedding序列 + camera_embedding_full = generate_camera_embeddings_sliding( + encoded_data.get('cam_emb', None), + 0, + max_history_frames, + 0, + 0, + use_real_poses=use_real_poses + ).to(device, dtype=model_dtype) + + print(f"完整camera序列shape: {camera_embedding_full.shape}") + + # 🔧 为Camera CFG创建无条件的camera embedding + if use_camera_cfg: + # 创建零camera embedding(无条件) + camera_embedding_uncond = torch.zeros_like(camera_embedding_full) + print(f"创建无条件camera embedding用于CFG") + + # 滑动窗口生成循环 + total_generated = 0 + all_generated_frames = [] + + while total_generated < total_frames_to_generate: + current_generation = min(frames_per_generation, total_frames_to_generate - total_generated) + print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}") + print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}") + + # FramePack数据准备 + framepack_data = prepare_framepack_sliding_window_with_camera( + history_latents, + 
current_generation, + camera_embedding_full, + start_frame, + max_history_frames + ) + + # 准备输入 + clean_latents = framepack_data['clean_latents'].unsqueeze(0) + clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0) + clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0) + camera_embedding = framepack_data['camera_embedding'].unsqueeze(0) + + # 🔧 为CFG准备无条件camera embedding + if use_camera_cfg: + camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0) + + # 索引处理 + latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu() + clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu() + clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu() + clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu() + + # 初始化要生成的latents + new_latents = torch.randn( + 1, C, current_generation, H, W, + device=device, dtype=model_dtype + ) + + extra_input = pipe.prepare_extra_input(new_latents) + + print(f"Camera embedding shape: {camera_embedding.shape}") + print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}") + + # 去噪循环 - 支持CFG + timesteps = pipe.scheduler.timesteps + + for i, timestep in enumerate(timesteps): + if i % 10 == 0: + print(f" 去噪步骤 {i+1}/{len(timesteps)}") + + timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype) + + with torch.no_grad(): + # 🔧 CFG推理 + if use_camera_cfg and camera_guidance_scale > 1.0: + # 条件预测(有camera) + noise_pred_cond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + 
**prompt_emb_pos, + **extra_input + ) + + # 无条件预测(无camera) + noise_pred_uncond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding_uncond_batch, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos), + **extra_input + ) + + # Camera CFG + noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # 如果同时使用Text CFG + if text_guidance_scale > 1.0 and prompt_emb_neg: + # 还需要计算text无条件预测 + noise_pred_text_uncond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_neg, + **extra_input + ) + + # 应用Text CFG到已经应用Camera CFG的结果 + noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond) + + elif text_guidance_scale > 1.0 and prompt_emb_neg: + # 只使用Text CFG + noise_pred_cond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + noise_pred_uncond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + 
clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_neg, + **extra_input + ) + + noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond) + + else: + # 标准推理(无CFG) + noise_pred = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents) + + # 更新历史 + new_latents_squeezed = new_latents.squeeze(0) + history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1) + + # 维护滑动窗口 + if history_latents.shape[1] > max_history_frames: + first_frame = history_latents[:, 0:1, :, :] + recent_frames = history_latents[:, -(max_history_frames-1):, :, :] + history_latents = torch.cat([first_frame, recent_frames], dim=1) + print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧") + + print(f"更新后history_latents shape: {history_latents.shape}") + + all_generated_frames.append(new_latents_squeezed) + total_generated += current_generation + + print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧") + + # 7. 
解码和保存 + print("\n🔧 解码生成的视频...") + + all_generated = torch.cat(all_generated_frames, dim=1) + final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0) + + print(f"最终视频shape: {final_video.shape}") + + decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + + print(f"Saving video to {output_path}") + + video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy() + video_np = (video_np * 0.5 + 0.5).clip(0, 1) + video_np = (video_np * 255).astype(np.uint8) + + with imageio.get_writer(output_path, fps=20) as writer: + for frame in video_np: + writer.append_data(frame) + + print(f"🔧 FramePack滑动窗口生成完成! 保存到: {output_path}") + print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧") + +def main(): + parser = argparse.ArgumentParser(description="Sekai FramePack滑动窗口视频生成 - 支持CFG") + parser.add_argument("--condition_pth", type=str, + default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth") + parser.add_argument("--start_frame", type=int, default=0) + parser.add_argument("--initial_condition_frames", type=int, default=16) + parser.add_argument("--frames_per_generation", type=int, default=8) + parser.add_argument("--total_frames_to_generate", type=int, default=40) + parser.add_argument("--max_history_frames", type=int, default=100) + parser.add_argument("--use_real_poses", action="store_true", default=False) + parser.add_argument("--dit_path", type=str, + default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack/step1000_framepack.ckpt") + parser.add_argument("--output_path", type=str, + default='/home/zhuyixuan05/ReCamMaster/sekai/infer_framepack_results/output_sekai_framepack_sliding.mp4') + parser.add_argument("--prompt", type=str, + default="A drone flying scene in a game world") + parser.add_argument("--device", type=str, default="cuda") + + # 🔧 新增CFG参数 + 
parser.add_argument("--use_camera_cfg", default=True, + help="使用Camera CFG") + parser.add_argument("--camera_guidance_scale", type=float, default=2.0, + help="Camera guidance scale for CFG") + parser.add_argument("--text_guidance_scale", type=float, default=1.0, + help="Text guidance scale for CFG") + + args = parser.parse_args() + + print(f"🔧 FramePack CFG生成设置:") + print(f"Camera CFG: {args.use_camera_cfg}") + if args.use_camera_cfg: + print(f"Camera guidance scale: {args.camera_guidance_scale}") + print(f"Text guidance scale: {args.text_guidance_scale}") + + inference_sekai_framepack_sliding_window( + condition_pth_path=args.condition_pth, + dit_path=args.dit_path, + output_path=args.output_path, + start_frame=args.start_frame, + initial_condition_frames=args.initial_condition_frames, + frames_per_generation=args.frames_per_generation, + total_frames_to_generate=args.total_frames_to_generate, + max_history_frames=args.max_history_frames, + device=args.device, + prompt=args.prompt, + use_real_poses=args.use_real_poses, + # 🔧 CFG参数 + use_camera_cfg=args.use_camera_cfg, + camera_guidance_scale=args.camera_guidance_scale, + text_guidance_scale=args.text_guidance_scale + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/infer_sekai_framepack_4.py b/scripts/infer_sekai_framepack_4.py new file mode 100644 index 0000000000000000000000000000000000000000..1ecac4ff1718d83f2fc63c9dfea7351f47fbdf59 --- /dev/null +++ b/scripts/infer_sekai_framepack_4.py @@ -0,0 +1,682 @@ +import os +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +import imageio +import json +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import argparse +from torchvision.transforms import v2 +from einops import rearrange +import copy + + +def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10): + """从pth文件加载预编码的视频数据""" + print(f"Loading encoded video from {pth_path}") + + encoded_data = 
torch.load(pth_path, weights_only=False, map_location="cpu") + full_latents = encoded_data['latents'] # [C, T, H, W] + + print(f"Full latents shape: {full_latents.shape}") + print(f"Extracting frames {start_frame} to {start_frame + num_frames}") + + if start_frame + num_frames > full_latents.shape[1]: + raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}") + + condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :] + print(f"Extracted condition latents shape: {condition_latents.shape}") + + return condition_latents, encoded_data + + +def compute_relative_pose(pose_a, pose_b, use_torch=False): + """计算相机B相对于相机A的相对位姿矩阵""" + assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}" + assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}" + + if use_torch: + if not isinstance(pose_a, torch.Tensor): + pose_a = torch.from_numpy(pose_a).float() + if not isinstance(pose_b, torch.Tensor): + pose_b = torch.from_numpy(pose_b).float() + + pose_a_inv = torch.inverse(pose_a) + relative_pose = torch.matmul(pose_b, pose_a_inv) + else: + if not isinstance(pose_a, np.ndarray): + pose_a = np.array(pose_a, dtype=np.float32) + if not isinstance(pose_b, np.ndarray): + pose_b = np.array(pose_b, dtype=np.float32) + + pose_a_inv = np.linalg.inv(pose_a) + relative_pose = np.matmul(pose_b, pose_a_inv) + + return relative_pose + +def replace_dit_model_in_manager(): + """替换DiT模型类为FramePack版本""" + from diffsynth.models.wan_video_dit_4 import WanModelFuture4 + from diffsynth.configs.model_config import model_loader_configs + + for i, config in enumerate(model_loader_configs): + keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config + + if 'wan_video_dit' in model_names: + new_model_names = [] + new_model_classes = [] + + for name, cls in zip(model_names, model_classes): + if name == 'wan_video_dit': + new_model_names.append(name) + 
new_model_classes.append(WanModelFuture4) + print(f"✅ 替换了模型类: {name} -> WanModelFuture4") + else: + new_model_names.append(name) + new_model_classes.append(cls) + + model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) + + +def add_framepack_components(dit_model): + """添加FramePack相关组件""" + if not hasattr(dit_model, 'clean_x_embedder'): + inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + if scale == "1x": + x = x.to(self.proj.weight.dtype) + return self.proj(x) + elif scale == "2x": + x = x.to(self.proj_2x.weight.dtype) + return self.proj_2x(x) + elif scale == "4x": + x = x.to(self.proj_4x.weight.dtype) + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + dit_model.clean_x_embedder = CleanXEmbedder(inner_dim) + model_dtype = next(dit_model.parameters()).dtype + dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype) + print("✅ 添加了FramePack的clean_x_embedder组件") + +def generate_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True): + """🔧 为滑动窗口生成camera embeddings - 修正长度计算,确保包含start_latent帧""" + time_compression_ratio = 4 + + # 🔧 计算FramePack实际需要的camera帧数 + # FramePack结构: 1(start) + 16(4x) + 2(2x) + 1(1x) + target_frames + framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames + + if use_real_poses and cam_data is not None and 'extrinsic' in cam_data: + print("🔧 使用真实camera数据") + cam_extrinsic = cam_data['extrinsic'] + + # 🔧 确保生成足够长的camera序列 + # 需要考虑:当前历史位置 + FramePack所需的完整结构 + max_needed_frames = max( + 
start_frame + current_history_length + new_frames, # 基础需求 + framepack_needed_frames, # FramePack结构需求 + 30 # 最小保证长度 + ) + + print(f"🔧 计算camera序列长度:") + print(f" - 基础需求: {start_frame + current_history_length + new_frames}") + print(f" - FramePack需求: {framepack_needed_frames}") + print(f" - 最终生成: {max_needed_frames}") + + relative_poses = [] + for i in range(max_needed_frames): + # 计算当前帧在原始序列中的位置 + frame_idx = i * time_compression_ratio + next_frame_idx = frame_idx + time_compression_ratio + + if next_frame_idx < len(cam_extrinsic): + cam_prev = cam_extrinsic[frame_idx] + cam_next = cam_extrinsic[next_frame_idx] + relative_pose = compute_relative_pose(cam_prev, cam_next) + relative_poses.append(torch.as_tensor(relative_pose[:3, :])) + else: + # 超出范围,使用零运动 + print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动") + relative_poses.append(torch.zeros(3, 4)) + + pose_embedding = torch.stack(relative_poses, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + + # 🔧 创建对应长度的mask序列 + mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) + # 从start_frame到current_history_length标记为condition + condition_end = min(start_frame + current_history_length, max_needed_frames) + mask[start_frame:condition_end] = 1.0 + + camera_embedding = torch.cat([pose_embedding, mask], dim=1) + print(f"🔧 真实camera embedding shape: {camera_embedding.shape} (总长度:{max_needed_frames})") + return camera_embedding.to(torch.bfloat16) + + else: + print("🔧 使用合成camera数据") + + # 🔧 确保合成数据也有足够长度 + max_needed_frames = max( + start_frame + current_history_length + new_frames, + framepack_needed_frames, + 30 + ) + + print(f"🔧 生成合成camera帧数: {max_needed_frames}") + print(f" - FramePack需求: {framepack_needed_frames}") + + relative_poses = [] + for i in range(max_needed_frames): + # 🔧 持续左转运动模式 + # 每帧旋转一个固定角度,同时前进 + yaw_per_frame = -0.05 # 每帧左转(正角度表示左转) + forward_speed = 0.005 # 每帧前进距离 + + # 计算当前累积角度 + current_yaw = i * yaw_per_frame + + # 创建相对变换矩阵(从第i帧到第i+1帧的变换) + pose = np.eye(4, dtype=np.float32) + + # 
旋转矩阵(绕Y轴左转) + cos_yaw = np.cos(yaw_per_frame) + sin_yaw = np.sin(yaw_per_frame) + + pose[0, 0] = cos_yaw + pose[0, 2] = sin_yaw + pose[2, 0] = -sin_yaw + pose[2, 2] = cos_yaw + + # 平移(在旋转后的局部坐标系中前进) + pose[2, 3] = -forward_speed # 局部Z轴负方向(前进) + + # 可选:添加轻微的向心运动,模拟圆形轨迹 + radius_drift = 0.002 # 向圆心的轻微漂移 + pose[0, 3] = radius_drift # 局部X轴负方向(向左) + + relative_pose = pose[:3, :] + relative_poses.append(torch.as_tensor(relative_pose)) + + pose_embedding = torch.stack(relative_poses, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + + # 创建对应长度的mask序列 + mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) + condition_end = min(start_frame + current_history_length, max_needed_frames) + mask[start_frame:condition_end] = 1.0 + + camera_embedding = torch.cat([pose_embedding, mask], dim=1) + print(f"🔧 合成camera embedding shape: {camera_embedding.shape} (总长度:{max_needed_frames})") + return camera_embedding.to(torch.bfloat16) + +def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49): + """🔧 FramePack滑动窗口机制 - 支持起始4帧+最后1帧的clean_latents""" + # history_latents: [C, T, H, W] 当前的历史latents + C, T, H, W = history_latents.shape + + # 🔧 固定索引结构:起始4帧 + 最后1帧 = 5帧clean_latents + total_indices_length = 1 + 16 + 2 + 5 + target_frames_to_generate # 修改:clean_latents现在是5帧 + indices = torch.arange(0, total_indices_length) + split_sizes = [1, 16, 2, 5, target_frames_to_generate] # 修改:clean_latents部分改为5 + clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \ + indices.split(split_sizes, dim=0) + + # clean_latents结构:起始4帧 + 最后1帧 + clean_latent_indices = clean_latent_1x_indices # 现在是5帧,包含起始4帧+最后1帧 + + # 🔧 检查camera长度是否足够 + if camera_embedding_full.shape[0] < total_indices_length: + shortage = total_indices_length - camera_embedding_full.shape[0] + padding = torch.zeros(shortage, camera_embedding_full.shape[1], + 
dtype=camera_embedding_full.dtype, device=camera_embedding_full.device) + camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0) + + # 🔧 从完整camera序列中选取对应部分 + combined_camera = camera_embedding_full[:total_indices_length, :].clone() + + # 🔧 根据当前history length重新设置mask + combined_camera[:, -1] = 0.0 # 先全部设为target (0) + + # 设置condition mask:前24帧根据实际历史长度决定(1+16+2+5) + if T > 0: + # 根据clean_latents的填充逻辑,确定哪些位置应该是condition + available_frames = min(T, 24) # 修改:现在前24帧对应clean latents + start_pos = 24 - available_frames + + # 对应的camera位置也应该标记为condition + combined_camera[start_pos:24, -1] = 1.0 # 修改:前24帧对应condition + + # target部分保持为0(已经在上面设置) + + print(f"🔧 Camera mask更新:") + print(f" - 历史帧数: {T}") + print(f" - 有效condition帧数: {available_frames if T > 0 else 0}") + print(f" - Condition mask (前24帧): {combined_camera[:24, -1].cpu().tolist()}") # 修改:24帧 + print(f" - Target mask (后{target_frames_to_generate}帧): {combined_camera[24:, -1].cpu().tolist()}") + + # 处理clean latents - 现在clean_latents是5帧:起始4帧+最后1帧 + clean_latents_combined = torch.zeros(C, 24, H, W, dtype=history_latents.dtype, device=history_latents.device) # 修改:24帧 + + if T > 0: + available_frames = min(T, 24) # 修改:24帧 + start_pos = 24 - available_frames + clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :] + + clean_latents_4x = clean_latents_combined[:, 0:16, :, :] + clean_latents_2x = clean_latents_combined[:, 16:18, :, :] + clean_latents_1x = clean_latents_combined[:, 18:23, :, :] # 修改:5帧clean latents + + # 构建clean_latents:起始4帧 + 最后1帧 + if T >= 5: + # 如果历史足够,取起始4帧+最后1帧 + start_latent = history_latents[:, 0:4, :, :] # 起始4帧 + last_latent = history_latents[:, -1:, :, :] # 最后1帧 + clean_latents = torch.cat([start_latent, last_latent], dim=1) # 5帧 + elif T > 0: + # 如果历史不足5帧,用0填充+最后1帧 + clean_latents = torch.zeros(C, 5, H, W, dtype=history_latents.dtype, device=history_latents.device) + # 从后往前填充历史帧 + clean_latents[:, -T:, :, :] = history_latents + else: + # 没有历史,全部用0 
+ clean_latents = torch.zeros(C, 5, H, W, dtype=history_latents.dtype, device=history_latents.device) + + return { + 'latent_indices': latent_indices, + 'clean_latents': clean_latents, # 现在是5帧:起始4帧+最后1帧 + 'clean_latents_2x': clean_latents_2x, + 'clean_latents_4x': clean_latents_4x, + 'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + 'camera_embedding': combined_camera, + 'current_length': T, + 'next_length': T + target_frames_to_generate + } + +def inference_sekai_framepack_sliding_window( + condition_pth_path, + dit_path, + output_path="sekai/infer_results/output_sekai_framepack_sliding.mp4", + start_frame=0, + initial_condition_frames=8, + frames_per_generation=4, + total_frames_to_generate=32, + max_history_frames=49, + device="cuda", + prompt="A video of a scene shot using a pedestrian's front camera while walking", + use_real_poses=True, + synthetic_direction="forward", + # 🔧 新增CFG参数 + use_camera_cfg=True, + camera_guidance_scale=2.0, + text_guidance_scale=7.5 +): + """ + 🔧 FramePack滑动窗口视频生成 - 支持Camera CFG + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + print(f"🔧 FramePack滑动窗口生成开始...") + print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}") + print(f"Text guidance scale: {text_guidance_scale}") + print(f"初始条件帧: {initial_condition_frames}, 每次生成: {frames_per_generation}, 总生成: {total_frames_to_generate}") + print(f"使用真实姿态: {use_real_poses}") + if not use_real_poses: + print(f"合成camera方向: {synthetic_direction}") + + # 1-3. 
模型初始化和组件添加(保持不变) + replace_dit_model_in_manager() + + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ]) + pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda") + + dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in pipe.dit.blocks: + block.cam_encoder = nn.Linear(13, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + add_framepack_components(pipe.dit) + + dit_state_dict = torch.load(dit_path, map_location="cpu") + pipe.dit.load_state_dict(dit_state_dict, strict=True) + pipe = pipe.to(device) + model_dtype = next(pipe.dit.parameters()).dtype + + if hasattr(pipe.dit, 'clean_x_embedder'): + pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype) + + pipe.scheduler.set_timesteps(50) + + # 4. 
加载初始条件 + print("Loading initial condition frames...") + initial_latents, encoded_data = load_encoded_video_from_pth( + condition_pth_path, + start_frame=start_frame, + num_frames=initial_condition_frames + ) + + # 空间裁剪 + target_height, target_width = 60, 104 + C, T, H, W = initial_latents.shape + + if H > target_height or W > target_width: + h_start = (H - target_height) // 2 + w_start = (W - target_width) // 2 + initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width] + H, W = target_height, target_width + + history_latents = initial_latents.to(device, dtype=model_dtype) + + print(f"初始history_latents shape: {history_latents.shape}") + + # 编码prompt - 支持CFG + if text_guidance_scale > 1.0: + # 编码positive prompt + prompt_emb_pos = pipe.encode_prompt(prompt) + # 编码negative prompt (空字符串) + prompt_emb_neg = pipe.encode_prompt("") + print(f"使用Text CFG,guidance scale: {text_guidance_scale}") + else: + prompt_emb_pos = pipe.encode_prompt(prompt) + prompt_emb_neg = None + print("不使用Text CFG") + + # 预生成完整的camera embedding序列 + camera_embedding_full = generate_camera_embeddings_sliding( + encoded_data.get('cam_emb', None), + 0, + max_history_frames, + 0, + 0, + use_real_poses=use_real_poses + ).to(device, dtype=model_dtype) + + print(f"完整camera序列shape: {camera_embedding_full.shape}") + + # 🔧 为Camera CFG创建无条件的camera embedding + if use_camera_cfg: + # 创建零camera embedding(无条件) + camera_embedding_uncond = torch.zeros_like(camera_embedding_full) + print(f"创建无条件camera embedding用于CFG") + + # 滑动窗口生成循环 + total_generated = 0 + all_generated_frames = [] + + while total_generated < total_frames_to_generate: + current_generation = min(frames_per_generation, total_frames_to_generate - total_generated) + print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}") + print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}") + + # FramePack数据准备 + framepack_data = prepare_framepack_sliding_window_with_camera( + history_latents, + 
current_generation, + camera_embedding_full, + start_frame, + max_history_frames + ) + + # 准备输入 + clean_latents = framepack_data['clean_latents'].unsqueeze(0) + clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0) + clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0) + camera_embedding = framepack_data['camera_embedding'].unsqueeze(0) + + # 🔧 为CFG准备无条件camera embedding + if use_camera_cfg: + camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0) + + # 索引处理 + latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu() + clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu() + clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu() + clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu() + + # 初始化要生成的latents + new_latents = torch.randn( + 1, C, current_generation, H, W, + device=device, dtype=model_dtype + ) + + extra_input = pipe.prepare_extra_input(new_latents) + + print(f"Camera embedding shape: {camera_embedding.shape}") + print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}") + + # 去噪循环 - 支持CFG + timesteps = pipe.scheduler.timesteps + + for i, timestep in enumerate(timesteps): + if i % 10 == 0: + print(f" 去噪步骤 {i+1}/{len(timesteps)}") + + timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype) + + with torch.no_grad(): + # 🔧 CFG推理 + if use_camera_cfg and camera_guidance_scale > 1.0: + # 条件预测(有camera) + noise_pred_cond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + 
**prompt_emb_pos, + **extra_input + ) + + # 无条件预测(无camera) + noise_pred_uncond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding_uncond_batch, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos), + **extra_input + ) + + # Camera CFG + noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # 如果同时使用Text CFG + if text_guidance_scale > 1.0 and prompt_emb_neg: + # 还需要计算text无条件预测 + noise_pred_text_uncond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_neg, + **extra_input + ) + + # 应用Text CFG到已经应用Camera CFG的结果 + noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond) + + elif text_guidance_scale > 1.0 and prompt_emb_neg: + # 只使用Text CFG + noise_pred_cond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + noise_pred_uncond = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + 
clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_neg, + **extra_input + ) + + noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond) + + else: + # 标准推理(无CFG) + noise_pred = pipe.dit( + new_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb_pos, + **extra_input + ) + + new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents) + + # 更新历史 + new_latents_squeezed = new_latents.squeeze(0) + history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1) + + # 维护滑动窗口 + if history_latents.shape[1] > max_history_frames: + first_frame = history_latents[:, 0:1, :, :] + recent_frames = history_latents[:, -(max_history_frames-1):, :, :] + history_latents = torch.cat([first_frame, recent_frames], dim=1) + print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧") + + print(f"更新后history_latents shape: {history_latents.shape}") + + all_generated_frames.append(new_latents_squeezed) + total_generated += current_generation + + print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧") + + # 7. 
def main():
    """CLI entry point for FramePack sliding-window generation with CFG."""
    parser = argparse.ArgumentParser(description="Sekai FramePack滑动窗口视频生成 - 支持CFG")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=60)
    parser.add_argument("--max_history_frames", type=int, default=100)
    # NOTE(review): store_true with default=True means this flag can never be
    # turned off from the CLI; left unchanged to keep the interface identical.
    parser.add_argument("--use_real_poses", action="store_true", default=True)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack_4/step34290_framepack.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/sekai/infer_framepack_results/output_sekai_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world")
    parser.add_argument("--device", type=str, default="cuda")

    # CFG options.
    # Fix: the original declared `--use_camera_cfg` with only `default=False`
    # (no action/type), so argparse required a value and ANY value — even the
    # string "False" — was truthy, silently enabling CFG. A store_true switch
    # gives the intended boolean semantics.
    parser.add_argument("--use_camera_cfg", action="store_true", default=False,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    args = parser.parse_args()

    print(f"🔧 FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")

    inference_sekai_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale
    )
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return camera B's pose expressed relative to camera A.

    Both inputs are 4x4 extrinsic matrices; the result is ``B @ inv(A)``,
    computed with torch when ``use_torch`` is set, otherwise with numpy.
    Non-tensor / non-ndarray inputs are converted to float32 first.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        mat_a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        mat_b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(mat_b, torch.inverse(mat_a))

    mat_a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    mat_b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(mat_b, np.linalg.inv(mat_a))
condition_frames + target_end = target_start + target_frames + latent_indices = torch.arange(target_start, target_end) + main_latents = full_latents[:, :, latent_indices, :, :] + + # 🔧 1x条件帧(起始帧 + 最后1帧) + clean_latent_indices = torch.tensor([start_frame, start_frame + condition_frames - 1]) + clean_latents = full_latents[:, :, clean_latent_indices, :, :] + + # 🔧 2x条件帧(最后2帧) + clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype) + clean_latent_2x_indices = torch.full((2,), -1, dtype=torch.long) + + if condition_frames >= 2: + actual_indices = torch.arange(max(start_frame, start_frame + condition_frames - 2), + start_frame + condition_frames) + start_pos = 2 - len(actual_indices) + clean_latents_2x[:, :, start_pos:, :, :] = full_latents[:, :, actual_indices, :, :] + clean_latent_2x_indices[start_pos:] = actual_indices + + # 🔧 4x条件帧(最多16帧) + clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype) + clean_latent_4x_indices = torch.full((16,), -1, dtype=torch.long) + + if condition_frames >= 1: + actual_indices = torch.arange(max(start_frame, start_frame + condition_frames - 16), + start_frame + condition_frames) + start_pos = 16 - len(actual_indices) + clean_latents_4x[:, :, start_pos:, :, :] = full_latents[:, :, actual_indices, :, :] + clean_latent_4x_indices[start_pos:] = actual_indices + + # 移除batch维度(如果原来没有) + if squeeze_batch: + main_latents = main_latents.squeeze(0) + clean_latents = clean_latents.squeeze(0) + clean_latents_2x = clean_latents_2x.squeeze(0) + clean_latents_4x = clean_latents_4x.squeeze(0) + + return { + 'latents': main_latents, + 'clean_latents': clean_latents, + 'clean_latents_2x': clean_latents_2x, + 'clean_latents_4x': clean_latents_4x, + 'latent_indices': latent_indices, + 'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + } + + +def generate_camera_poses_from_data(cam_data, start_frame, condition_frames, 
def generate_synthetic_camera_poses(direction="forward", target_frames=10, condition_frames=20):
    """Generate a synthetic camera pose sequence for a fixed motion direction.

    Builds ``condition_frames + target_frames`` absolute 4x4 poses along the
    requested trajectory, converts consecutive pairs to relative poses
    (``next @ inv(prev)``), and appends a condition/target mask column.

    Args:
        direction: one of "forward", "backward", "left_turn", "right_turn".
        target_frames: number of frames the model should generate.
        condition_frames: number of conditioning (history) frames.

    Returns:
        bfloat16 tensor of shape [total_frames, 13]: 12 flattened 3x4 pose
        values plus a mask column (1.0 = condition frame, 0.0 = target frame).
    """
    total_frames = condition_frames + target_frames
    poses = []

    for i in range(total_frames):
        t = i / max(1, total_frames - 1)
        pose = np.eye(4, dtype=np.float32)

        if direction == "forward":
            pose[2, 3] = -t * 0.04
        elif direction == "backward":
            pose[2, 3] = t * 2.0
        elif direction == "left_turn":
            pose[2, 3] = -t * 0.03
            pose[0, 3] = t * 0.02
            yaw = t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)
        elif direction == "right_turn":
            pose[2, 3] = -t * 0.03
            pose[0, 3] = -t * 0.02
            yaw = -t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)

        poses.append(pose)

    # Relative pose between each pair of consecutive absolute poses
    # (same math as compute_relative_pose, inlined: next @ inv(prev)).
    relative_poses = []
    for prev, nxt in zip(poses, poses[1:]):
        rel = np.matmul(nxt, np.linalg.inv(prev))
        relative_poses.append(torch.as_tensor(rel[:3, :]))

    # Pad to total_frames. Fix: when total_frames == 1 there is no pair to
    # diff, so the list is empty and the original `relative_poses[-1]`
    # raised IndexError; fall back to an identity relative pose instead.
    if len(relative_poses) < total_frames:
        if relative_poses:
            relative_poses.append(relative_poses[-1])
        else:
            relative_poses.append(torch.as_tensor(np.eye(4, dtype=np.float32)[:3, :]))

    pose_embedding = torch.stack(relative_poses[:total_frames], dim=0)
    # Flatten each 3x4 matrix row-major to 12 values (was einops
    # rearrange 'b c d -> b (c d)'; reshape is equivalent here).
    pose_embedding = pose_embedding.reshape(total_frames, 12)

    # Mask column: 1.0 marks condition frames, 0.0 marks target frames.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0
    mask = mask.view(-1, 1)

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]
    print(f"Generated {direction} movement poses: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
dit_model.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + # 🔧 确保输入和权重的数据类型匹配 + if scale == "1x": + x = x.to(self.proj.weight.dtype) + return self.proj(x) + elif scale == "2x": + x = x.to(self.proj_2x.weight.dtype) + return self.proj_2x(x) + elif scale == "4x": + x = x.to(self.proj_4x.weight.dtype) + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + dit_model.clean_x_embedder = CleanXEmbedder(inner_dim) + # 🔧 修复:使用模型参数的dtype而不是模型的dtype属性 + model_dtype = next(dit_model.parameters()).dtype + dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype) + print("✅ 添加了FramePack的clean_x_embedder组件") + +def inference_sekai_framepack_from_pth( + condition_pth_path, + dit_path, + output_path="sekai/infer_results/output_sekai_framepack.mp4", + start_frame=0, + condition_frames=10, + target_frames=2, + device="cuda", + prompt="A video of a scene shot using a pedestrian's front camera while walking", + direction="forward", + use_real_poses=True +): + """ + FramePack风格的Sekai视频推理 + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + print(f"Setting up FramePack models for {direction} movement...") + + # 1. 替换模型类并加载模型 + replace_dit_model_in_manager() + + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ]) + pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda") + + # 2. 
添加camera components和FramePack components + dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in pipe.dit.blocks: + block.cam_encoder = nn.Linear(13, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + # 添加FramePack组件 + add_framepack_components(pipe.dit) + + # 3. 加载训练的权重 + dit_state_dict = torch.load(dit_path, map_location="cpu") + pipe.dit.load_state_dict(dit_state_dict, strict=True) + + pipe = pipe.to(device) + model_dtype = next(pipe.dit.parameters()).dtype + pipe.dit = pipe.dit.to(dtype=model_dtype) + if hasattr(pipe.dit, 'clean_x_embedder'): + pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype) + + pipe.scheduler.set_timesteps(50) + print("Loading condition video from pth...") + + # 4. 加载条件视频数据 + condition_latents, encoded_data = load_encoded_video_from_pth( + condition_pth_path, + start_frame=start_frame, + num_frames=condition_frames + ) + + print("Preparing FramePack inputs...") + + # 5. 🔧 准备FramePack风格的多尺度输入 + full_latents = encoded_data['latents'] + framepack_inputs = prepare_framepack_inputs( + full_latents, condition_frames, target_frames, start_frame + ) + + # 🔧 转换为正确的设备和数据类型,确保与DiT模型一致 + for key in framepack_inputs: + if torch.is_tensor(framepack_inputs[key]): + framepack_inputs[key] = framepack_inputs[key].to(device, dtype=model_dtype) + + print("Processing poses...") + + # 6. 
生成相机pose embedding + if use_real_poses and 'cam_emb' in encoded_data: + print("Using real camera poses from data") + camera_embedding = generate_camera_poses_from_data( + encoded_data['cam_emb'], + start_frame=start_frame, + condition_frames=condition_frames, + target_frames=target_frames + ) + else: + print(f"Using synthetic {direction} poses") + camera_embedding = generate_synthetic_camera_poses( + direction=direction, + target_frames=target_frames, + condition_frames=condition_frames + ) + + camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=model_dtype) + print("Encoding prompt...") + + # 7. 编码文本提示 + prompt_emb = pipe.encode_prompt(prompt) + print("Generating video...") + + # 8. 生成目标latents + batch_size = 1 + channels = framepack_inputs['latents'].shape[0] # 现在latents没有batch维度 + latent_height = framepack_inputs['latents'].shape[2] + latent_width = framepack_inputs['latents'].shape[3] + + # 空间裁剪以节省内存 + target_height, target_width = 60, 104 + + if latent_height > target_height or latent_width > target_width: + h_start = (latent_height - target_height) // 2 + w_start = (latent_width - target_width) // 2 + + # 裁剪所有inputs + for key in ['latents', 'clean_latents', 'clean_latents_2x', 'clean_latents_4x']: + if key in framepack_inputs and torch.is_tensor(framepack_inputs[key]): + framepack_inputs[key] = framepack_inputs[key][:, :, + h_start:h_start+target_height, + w_start:w_start+target_width] + + latent_height = target_height + latent_width = target_width + + # 为推理添加batch维度 + for key in ['latents', 'clean_latents', 'clean_latents_2x', 'clean_latents_4x']: + if key in framepack_inputs and torch.is_tensor(framepack_inputs[key]): + framepack_inputs[key] = framepack_inputs[key].unsqueeze(0) + + # 🔧 修复:为索引张量添加batch维度并确保正确的数据类型 + for key in ['latent_indices', 'clean_latent_indices', 'clean_latent_2x_indices', 'clean_latent_4x_indices']: + if key in framepack_inputs and torch.is_tensor(framepack_inputs[key]): + # 确保索引是long类型,并且在CPU上 + framepack_inputs[key] = 
framepack_inputs[key].long().cpu().unsqueeze(0) + + # 初始化target latents with noise + target_latents = torch.randn( + batch_size, channels, target_frames, latent_height, latent_width, + device=device, dtype=model_dtype # 🔧 使用模型的dtype + ) + + print(f"FramePack inputs:") + for key, value in framepack_inputs.items(): + if torch.is_tensor(value): + print(f" {key}: {value.shape} {value.dtype}") + else: + print(f" {key}: {value}") + print(f"Camera embedding shape: {camera_embedding.shape}") + print(f"Target latents shape: {target_latents.shape}") + + # 9. 准备额外输入 + extra_input = pipe.prepare_extra_input(target_latents) + + # 10. 🔧 FramePack风格的去噪循环 + timesteps = pipe.scheduler.timesteps + + for i, timestep in enumerate(timesteps): + print(f"Denoising step {i+1}/{len(timesteps)}") + + timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype) + + # 🔧 使用FramePack风格的forward调用 + with torch.no_grad(): + noise_pred = pipe.dit( + target_latents, + timestep=timestep_tensor, + cam_emb=camera_embedding, + # FramePack参数 + latent_indices=framepack_inputs['latent_indices'], + clean_latents=framepack_inputs['clean_latents'], + clean_latent_indices=framepack_inputs['clean_latent_indices'], + clean_latents_2x=framepack_inputs['clean_latents_2x'], + clean_latent_2x_indices=framepack_inputs['clean_latent_2x_indices'], + clean_latents_4x=framepack_inputs['clean_latents_4x'], + clean_latent_4x_indices=framepack_inputs['clean_latent_4x_indices'], + **prompt_emb, + **extra_input + ) + + # 更新target latents + target_latents = pipe.scheduler.step(noise_pred, timestep, target_latents) + + print("Decoding video...") + + # 11. 解码最终视频 + # 拼接condition和target用于解码 + condition_for_decode = framepack_inputs['clean_latents'][:, :, -1:, :, :] # 取最后一帧作为条件 + final_video = torch.cat([condition_for_decode, target_latents], dim=2) + decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + + # 12. 
def main():
    """Parse CLI arguments and run single-shot FramePack inference."""
    parser = argparse.ArgumentParser(description="Sekai FramePack Video Generation Inference from PTH")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0,
                        help="Starting frame index (compressed latent frames)")
    parser.add_argument("--condition_frames", type=int, default=8,
                        help="Number of condition frames (compressed latent frames)")
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate (compressed latent frames)")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement (if not using real poses)")
    parser.add_argument("--use_real_poses", action="store_true", default=False,
                        help="Use real camera poses from data")
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack/step24000_framepack.ckpt",
                        help="Path to trained FramePack DiT checkpoint")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/sekai/infer_framepack_results/output_sekai_framepack.mp4',
                        help="Output video path")
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world",
                        help="Text prompt for generation")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    if args.output_path is not None:
        output_path = args.output_path
    else:
        # NOTE(review): unreachable with the current non-None --output_path
        # default; kept for parity. Confirm whether the default should be None
        # so this auto-naming path can actually trigger.
        stem = os.path.splitext(os.path.basename(args.condition_pth))[0]
        out_dir = "sekai/infer_framepack_results"
        os.makedirs(out_dir, exist_ok=True)
        tag = "real" if args.use_real_poses else args.direction
        out_name = f"{stem}_framepack_{tag}_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"
        output_path = os.path.join(out_dir, out_name)

    print(f"🔧 FramePack Inference Settings:")
    print(f"Input pth: {args.condition_pth}")
    print(f"Start frame: {args.start_frame} (compressed)")
    print(f"Condition frames: {args.condition_frames} (compressed, original: {args.condition_frames * 4})")
    print(f"Target frames: {args.target_frames} (compressed, original: {args.target_frames * 4})")
    print(f"Use real poses: {args.use_real_poses}")
    print(f"Direction: {args.direction}")
    print(f"Output video will be saved to: {output_path}")

    inference_sekai_framepack_from_pth(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=output_path,
        start_frame=args.start_frame,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        device=args.device,
        prompt=args.prompt,
        direction=args.direction,
        use_real_poses=args.use_real_poses
    )
def compute_relative_pose_matrix(pose1, pose2):
    """Compute the relative pose between two consecutive frames.

    Args:
        pose1: camera pose of frame i, shape (7,): [tx, ty, tz, qx, qy, qz, qw].
        pose2: camera pose of frame i+1, same layout.

    Returns:
        np.ndarray of shape (3, 4): [R_rel | t_rel], with
        R_rel = R2 @ R1^T and t_rel = R1^T @ (t2 - t1).

    NOTE(review): R_rel is composed in the world frame while t_rel is expressed
    in frame i's local frame — confirm this mixed convention matches the
    convention the model was trained with.
    """
    # Split translation and quaternion parts.
    t1 = pose1[:3]
    q1 = pose1[3:]
    t2 = pose2[:3]
    q2 = pose2[3:]

    # 1. Relative rotation: R_rel = R2 @ R1^{-1}.
    rot1 = R.from_quat(q1)
    rot2 = R.from_quat(q2)
    rot_rel = rot2 * rot1.inv()
    R_rel = rot_rel.as_matrix()

    # 2. Relative translation expressed in frame i: t_rel = R1^T @ (t2 - t1).
    R1_T = rot1.as_matrix().T
    t_rel = R1_T @ (t2 - t1)

    # 3. Pack into a 3x4 matrix [R_rel | t_rel].
    relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])

    return relative_matrix

def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a window of pre-encoded VAE latents from a .pth file.

    Args:
        pth_path: path to a torch-saved dict containing key 'latents' [C, T, H, W].
        start_frame: first (compressed) latent frame to extract.
        num_frames: number of latent frames to extract.

    Returns:
        (condition_latents [C, num_frames, H, W], the full loaded dict).

    Raises:
        ValueError: when the requested window exceeds the available frames.
    """
    print(f"Loading encoded video from {pth_path}")

    # weights_only=False executes arbitrary pickled code — only load trusted files.
    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    if start_frame + num_frames > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data

def replace_dit_model_in_manager():
    """Patch diffsynth's model loader configs so 'wan_video_dit' resolves to WanModelFuture.

    Must run BEFORE ModelManager.load_models(), since the configs are read at load time.
    """
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    # Rewrite each matching entry of model_loader_configs in place.
    for i, config in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config

        # Only touch configs that load a 'wan_video_dit' model.
        if 'wan_video_dit' in model_names:
            new_model_names = []
            new_model_classes = []

            for name, cls in zip(model_names, model_classes):
                if name == 'wan_video_dit':
                    new_model_names.append(name)              # keep the registered name
                    new_model_classes.append(WanModelFuture)  # swap in the future-aware class
                    print(f"✅ 替换了模型类: {name} -> WanModelFuture")
                else:
                    new_model_names.append(name)
                    new_model_classes.append(cls)

            model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource)

def add_framepack_components(dit_model):
    """Attach the FramePack clean_x_embedder (multi-scale latent patchifier) to the DiT.

    Idempotent: does nothing when the attribute already exists. The new module is
    cast to the DiT's parameter dtype.
    """
    if not hasattr(dit_model, 'clean_x_embedder'):
        inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

        class CleanXEmbedder(nn.Module):
            """Patchify clean latents at 1x/2x/4x temporal-spatial compression (cf. hunyuan_video_packed.py)."""

            def __init__(self, inner_dim):
                super().__init__()
                self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
                self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
                self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

            def forward(self, x, scale="1x"):
                if scale == "1x":
                    return self.proj(x)
                elif scale == "2x":
                    return self.proj_2x(x)
                elif scale == "4x":
                    return self.proj_4x(x)
                else:
                    raise ValueError(f"Unsupported scale: {scale}")

        dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
        model_dtype = next(dit_model.parameters()).dtype
        dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
        print("✅ 添加了FramePack的clean_x_embedder组件")

def generate_spatialvid_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build the per-frame camera embedding sequence for SpatialVid (sliding-window variant).

    Returns a (N, 13) bfloat16 tensor: 12 flattened [R|t] values plus a trailing
    condition mask (1.0 = condition frame, 0.0 = target frame). N is the max of
    the basic need, the FramePack index requirement (1+16+2+1+new_frames), and 30.

    `total_generated` is accepted for signature compatibility but unused here.
    """
    # FramePack index layout requires this many camera frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实SpatialVid camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Generate a sequence long enough for every consumer.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算SpatialVid camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # SpatialVid-specific: poses are sampled every 1 frame (not every 4).
            frame_idx = i
            next_frame_idx = frame_idx + 1

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_cam = compute_relative_pose_matrix(cam_prev, cam_next)
                # BUGFIX: cast to float32 explicitly. scipy/numpy yield float64
                # here, while the zero-motion fallback below and the mask are
                # float32 — torch.stack/torch.cat require a single dtype, so the
                # real-pose path used to raise a dtype-mismatch RuntimeError.
                relative_poses.append(torch.as_tensor(relative_cam[:3, :], dtype=torch.float32))
            else:
                # Past the end of the recorded trajectory: assume zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        # Flatten each 3x4 pose to a 12-vector (row-major), same as
        # rearrange(x, 'b c d -> b (c d)').
        pose_embedding = pose_embedding.reshape(pose_embedding.shape[0], -1)

        # Condition mask: frames [start_frame, start_frame+history) are condition.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 SpatialVid真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用SpatialVid合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成SpatialVid合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # Synthetic indoor-walking motion: slight yaw sway + forward drift.
            yaw_per_frame = 0.03 * np.sin(i * 0.1)  # left/right sway
            forward_speed = 0.008                   # per-frame forward distance

            pose = np.eye(4, dtype=np.float32)

            # Rotation about the Y axis.
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Translation: forward motion plus slight vertical bob.
            pose[2, 3] = -forward_speed             # local -Z = forward
            pose[1, 3] = 0.002 * np.sin(i * 0.15)   # vertical bobbing

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = pose_embedding.reshape(pose_embedding.shape[0], -1)

        # Condition mask, same layout as the real-pose branch.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 SpatialVid合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """Assemble FramePack conditioning tensors for one sliding-window step (SpatialVid).

    Args:
        history_latents: [C, T, H, W] latents generated/conditioned so far.
        target_frames_to_generate: number of new latent frames for this step.
        camera_embedding_full: (N, 13) precomputed camera sequence; zero-padded
            if shorter than the required index length.
        start_frame: kept for signature compatibility (mask is recomputed from T).
        max_history_frames: kept for signature compatibility.

    Returns:
        dict of index tensors, clean-latent tensors at 1x/2x/4x, the camera
        embedding slice with a refreshed condition mask, and length bookkeeping.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout: [start(1) | 4x(16) | 2x(2) | 1x(1) | targets].
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    indices = torch.arange(0, total_indices_length)
    split_sizes = [1, 16, 2, 1, target_frames_to_generate]
    clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \
        indices.split(split_sizes, dim=0)
    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence if it is shorter than the index layout.
    if camera_embedding_full.shape[0] < total_indices_length:
        shortage = total_indices_length - camera_embedding_full.shape[0]
        padding = torch.zeros(shortage, camera_embedding_full.shape[1],
                              dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)

    # Slice the part we need and rebuild the condition mask for the current T.
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()
    combined_camera[:, -1] = 0.0  # mark everything as target first

    # The 19 clean-latent slots are right-aligned with the available history.
    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        combined_camera[start_pos:19, -1] = 1.0

    print(f"🔧 SpatialVid Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")

    # Right-align up to 19 most recent history frames into the clean-latent buffer.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)

    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # The anchor frame: first history frame (or zeros when no history exists).
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
def inference_spatialvid_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="spatialvid_results/output_spatialvid_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A man walking through indoor spaces with a first-person view",
    use_real_poses=True,
    # Classifier-free-guidance knobs
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0
):
    """SpatialVid FramePack sliding-window video generation.

    Loads pre-encoded latents as the initial condition, then autoregressively
    denoises `total_frames_to_generate` latent frames in chunks of
    `frames_per_generation`, feeding each chunk back into a bounded history
    window. Optionally applies camera CFG (conditional vs. zeroed camera
    embedding) and text CFG (prompt vs. empty prompt). The final latent
    sequence is VAE-decoded and written as an mp4.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 SpatialVid FramePack滑动窗口生成开始...")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")

    # 1. Model initialization — must patch the DiT class BEFORE loading weights.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Attach per-block camera encoders: a zero-initialized 13->dim linear
    #    (12 pose values + 1 condition mask) and an identity projector, so the
    #    untrained modules initially contribute nothing.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Attach the FramePack clean-latent embedder.
    add_framepack_components(pipe.dit)

    # 4. Load the fine-tuned DiT checkpoint (must match the augmented module set).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 5. Load the initial condition latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop the latent spatial dims to the model's expected size.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 6. Encode the prompt; also encode the empty prompt when text CFG is on.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 7. Precompute the full camera-embedding sequence once; each generation
    #    step slices the part it needs.
    camera_embedding_full = generate_spatialvid_camera_embeddings_sliding(
        encoded_data.get('cam_emb', None),
        0,
        max_history_frames,
        0,
        0,
        use_real_poses=use_real_poses
    ).to(device, dtype=model_dtype)

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 8. Unconditional (all-zero) camera embedding for camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 9. Sliding-window autoregressive generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # Build FramePack conditioning tensors for this step.
        framepack_data = prepare_framepack_sliding_window_with_camera(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # Matching-length slice of the zeroed camera sequence for CFG.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)

        # Index tensors stay on CPU.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Fresh noise for the frames to be generated this step.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop with optional camera/text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # Conditional prediction (camera + prompt).
                noise_pred_pos = pipe.dit(
                    new_latents,
                    timestep=timestep_tensor,
                    cam_emb=camera_embedding,
                    latent_indices=latent_indices,
                    clean_latents=clean_latents,
                    clean_latent_indices=clean_latent_indices,
                    clean_latents_2x=clean_latents_2x,
                    clean_latent_2x_indices=clean_latent_2x_indices,
                    clean_latents_4x=clean_latents_4x,
                    clean_latent_4x_indices=clean_latent_4x_indices,
                    **prompt_emb_pos,
                    **extra_input
                )

                # Camera CFG: second pass with the zeroed camera embedding.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_pos - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_pos

                # Text CFG: third pass with the empty-prompt embedding.
                if prompt_emb_neg is not None and text_guidance_scale > 1.0:
                    noise_pred_neg = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_neg + text_guidance_scale * (noise_pred - noise_pred_neg)

                new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the denoised chunk to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Bound the history: keep the anchor (first) frame plus the most recent
        # max_history_frames-1 frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 10. Decode the (condition + generated) latents and save as video.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] -> [0, 255] uint8, frame-major for imageio.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 SpatialVid FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")

def main():
    """CLI entry point for SpatialVid FramePack sliding-window generation."""
    parser = argparse.ArgumentParser(description="SpatialVid FramePack滑动窗口视频生成")

    # Basic parameters.
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth",
                        help="输入编码视频路径")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=16)
    parser.add_argument("--max_history_frames", type=int, default=100)
    # NOTE(review): store_true with default=True means this flag can never be
    # disabled from the command line — consider a BooleanOptionalAction.
    parser.add_argument("--use_real_poses", action="store_true", default=True)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/spatialvid/spatialvid_framepack_random/step50.ckpt",
                        help="训练好的模型权重路径")
    parser.add_argument("--output_path", type=str,
                        default='spatialvid_results/output_spatialvid_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A man walking through indoor spaces with a first-person view")
    parser.add_argument("--device", type=str, default="cuda")

    # CFG parameters.
    # NOTE(review): same store_true/default=True issue as --use_real_poses.
    parser.add_argument("--use_camera_cfg", action="store_true", default=True,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    args = parser.parse_args()

    print(f"🔧 SpatialVid FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"SpatialVid特有特性: camera间隔为1帧")

    inference_spatialvid_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale
    )

if __name__ == "__main__":
    main()
class Camera(object):
    """Thin wrapper storing a camera-to-world matrix and its inverse (world-to-camera)."""

    def __init__(self, c2w):
        c2w_mat = np.array(c2w).reshape(4, 4)
        self.c2w_mat = c2w_mat
        self.w2c_mat = np.linalg.inv(c2w_mat)

class TextVideoCameraDataset(torch.utils.data.Dataset):
    """Dataset yielding (text, 81-frame video, target-camera trajectory) triples.

    Videos are listed in a metadata CSV ('file_name', 'text' columns) relative to
    `<base_path>/videos`. The target trajectory is read from a fixed JSON file of
    camera extrinsics and converted into 21 flattened relative poses (one per
    every 4th frame), matching the latent temporal compression.
    """

    def __init__(self, base_path, metadata_path, args, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "videos", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.is_i2v = is_i2v
        self.args = args
        # Selected camera trajectory id; formatted via int() in __getitem__ even
        # though the CLI declares it type=str.
        self.cam_type = self.args.cam_type

        # NOTE(review): v2.ToTensor is deprecated in torchvision; also CenterCrop
        # runs before Resize here, while crop_and_resize() already rescales to
        # cover the target size — confirm the intended crop/resize order.
        self.frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])


    def crop_and_resize(self, image):
        # Scale the PIL image up so both dimensions cover the target size
        # (aspect ratio preserved); the transform pipeline center-crops after.
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image


    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        # Returns C,T,H,W tensor (plus the raw first frame for i2v mode), or
        # None when the clip is too short for the requested window.
        reader = imageio.get_reader(file_path)
        if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None

        frames = []
        first_frame = None
        for frame_id in range(num_frames):
            frame = reader.get_data(start_frame_id + frame_id * interval)
            frame = Image.fromarray(frame)
            frame = self.crop_and_resize(frame)
            if first_frame is None:
                first_frame = np.array(frame)
            frame = frame_process(frame)
            frames.append(frame)
        reader.close()

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        if self.is_i2v:
            return frames, first_frame
        else:
            return frames


    def is_image(self, file_path):
        # Extension-based check; case-insensitive.
        file_ext_name = file_path.split(".")[-1]
        if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
            return True
        return False


    def load_video(self, file_path):
        # With the defaults (81 frames of an 81-frame window, interval 1) the
        # randint range is [0, 1), so start_frame_id is always 0.
        start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
        frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
        return frames


    def parse_matrix(self, matrix_str):
        # Parse a "[a b c] [d e f] ..." style string into an ndarray.
        rows = matrix_str.strip().split('] [')
        matrix = []
        for row in rows:
            row = row.replace('[', '').replace(']', '')
            matrix.append(list(map(float, row.split())))
        return np.array(matrix)


    def get_relative_pose(self, cam_params):
        """Express each camera-to-world pose relative to the first camera."""
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]

        cam_to_origin = 0
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, -cam_to_origin],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        # Map the first camera to the canonical origin, then apply the same
        # transform to the rest of the trajectory.
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        ret_poses = np.array(ret_poses, dtype=np.float32)
        return ret_poses


    def __getitem__(self, data_id):
        text = self.text[data_id]
        path = self.path[data_id]
        video = self.load_video(path)
        if video is None:
            raise ValueError(f"{path} is not a valid video.")
        num_frames = video.shape[1]
        assert num_frames == 81
        data = {"text": text, "video": video, "path": path}

        # Load the target camera trajectory from the bundled example data.
        tgt_camera_path = "./example_test_data/cameras/camera_extrinsics.json"
        with open(tgt_camera_path, 'r') as file:
            cam_data = json.load(file)

        # One pose per 4 video frames -> 21 poses for 81 frames, matching the
        # VAE's temporal compression.
        cam_idx = list(range(num_frames))[::4]
        traj = [self.parse_matrix(cam_data[f"frame{idx}"][f"cam{int(self.cam_type):02d}"]) for idx in cam_idx]
        traj = np.stack(traj).transpose(0, 2, 1)
        c2ws = []
        for c2w in traj:
            # Axis permutation + sign flip + cm->m style rescale of translation;
            # converts the stored convention into the model's camera convention.
            c2w = c2w[:, [1, 2, 0, 3]]
            c2w[:3, 1] *= -1.
            c2w[:3, 3] /= 100
            c2ws.append(c2w)
        tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
        relative_poses = []
        for i in range(len(tgt_cam_params)):
            relative_pose = self.get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]])
            # Keep only the second (relative) pose's top 3x4 block.
            relative_poses.append(torch.as_tensor(relative_pose)[:,:3,:][1])
        pose_embedding = torch.stack(relative_poses, dim=0)  # 21x3x4
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
        data['camera'] = pose_embedding.to(torch.bfloat16)
        return data


    def __len__(self):
        return len(self.path)

def parse_args():
    """Parse CLI arguments for ReCamMaster inference."""
    parser = argparse.ArgumentParser(description="ReCamMaster Inference")
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="./example_test_data",
        help="The path of the Dataset.",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="./models/ReCamMaster/checkpoints/step20000.ckpt",
        help="Path to save the model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./results",
        help="Path to save the results.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    # NOTE(review): type=str with an int default (1); downstream formats it
    # with int(), so both work, but the declaration is inconsistent.
    parser.add_argument(
        "--cam_type",
        type=str,
        default=1,
    )
    parser.add_argument(
        "--cfg_scale",
        type=float,
        default=5.0,
    )
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    args = parse_args()

    # 1. Load the Wan2.1 pre-trained weights (DiT, T5 text encoder, VAE).
    manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    manager.load_models([
        "/share_zhuyixuan05/zhuyixuan05/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a/diffusion_pytorch_model.safetensors",
        "/share_zhuyixuan05/zhuyixuan05/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a/models_t5_umt5-xxl-enc-bf16.pth",
        "/share_zhuyixuan05/zhuyixuan05/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(manager, device="cuda")

    # 2. Bolt on the modules ReCamMaster introduces: a zero-initialized 12->dim
    #    camera encoder and an identity projector per DiT block, so the fresh
    #    modules are a no-op until the checkpoint overwrites them.
    hidden_dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for blk in pipe.dit.blocks:
        blk.cam_encoder = nn.Linear(12, hidden_dim)
        blk.projector = nn.Linear(hidden_dim, hidden_dim)
        blk.cam_encoder.weight.data.zero_()
        blk.cam_encoder.bias.data.zero_()
        blk.projector.weight = nn.Parameter(torch.eye(hidden_dim))
        blk.projector.bias = nn.Parameter(torch.zeros(hidden_dim))

    # 3. Load the ReCamMaster checkpoint over the augmented DiT.
    checkpoint = torch.load(args.ckpt_path, map_location="cpu")
    pipe.dit.load_state_dict(checkpoint, strict=True)
    pipe.to("cuda")
    pipe.to(dtype=torch.bfloat16)

    result_dir = os.path.join(args.output_dir, f"cam_type{args.cam_type}")
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    # 4. Build the test set (source video, target camera trajectory, text).
    test_set = TextVideoCameraDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        args,
    )
    loader = torch.utils.data.DataLoader(
        test_set,
        shuffle=False,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    # 5. Run inference and save one mp4 per test sample.
    for batch_idx, batch in enumerate(loader):
        video = pipe(
            prompt=batch["text"],
            negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
            source_video=batch["video"],
            target_camera=batch["camera"],
            cfg_scale=args.cfg_scale,
            num_inference_steps=50,
            seed=0, tiled=True
        )
        save_video(video, os.path.join(result_dir, f"video{batch_idx}.mp4"), fps=30, quality=5)
def calculate_relative_pose(pose_current, pose_reference):
    """Compute the translation of `pose_current` relative to `pose_reference`.

    Args:
        pose_current / pose_reference: nuScenes ego_pose records (dicts with
            'translation' [x, y, z], 'rotation' quaternion, 'timestamp').

    Returns:
        dict with the world-frame translation delta, both raw rotations
        (unchanged, left for downstream use), and the current timestamp.
    """
    trans_ref = np.array(pose_reference['translation'])
    trans_cur = np.array(pose_current['translation'])

    # World-frame displacement; rotations are passed through untransformed.
    relative_translation = trans_cur - trans_ref

    relative_pose = {
        'relative_translation': relative_translation.tolist(),
        'current_rotation': pose_current['rotation'],
        'reference_rotation': pose_reference['rotation'],
        'timestamp': pose_current['timestamp']
    }

    return relative_pose

def extract_full_scene_with_keyframes(nusc, scene_token, scene_name, output_dir, channel):
    """Render one scene's full camera stream to mp4 and record keyframe poses.

    Walks the sample_data linked list of `channel` starting at the scene's first
    sample, collects every frame token, and records ego poses at keyframes only.
    Skips scenes with fewer than 30 frames or 3 keyframes. Writes
    `scene_info.json` next to the rendered video.

    Returns:
        1 on success, 0 when the scene was skipped or rendering failed.
    """
    scene_record = nusc.get('scene', scene_token)
    current_sample_token = scene_record['first_sample_token']

    # Collect all sample_data tokens, per-frame ego poses (None off-keyframe),
    # and the indices of keyframes within the stream.
    all_sd_tokens = []
    all_ego_poses = []
    keyframe_indices = []
    frame_index = 0

    while current_sample_token:
        sample_record = nusc.get('sample', current_sample_token)

        if channel in sample_record['data']:
            current_sd_token = sample_record['data'][channel]

            # From the first keyframe, follow 'next' through the whole stream.
            while current_sd_token:
                sd_record = nusc.get('sample_data', current_sd_token)
                all_sd_tokens.append(current_sd_token)

                if sd_record['is_key_frame']:
                    ego_pose_record = nusc.get('ego_pose', sd_record['ego_pose_token'])
                    all_ego_poses.append(ego_pose_record)
                    keyframe_indices.append(frame_index)
                else:
                    all_ego_poses.append(None)

                frame_index += 1
                current_sd_token = sd_record['next'] if sd_record['next'] != '' else None

            # The inner walk already covered the full stream; stop here.
            break

        current_sample_token = sample_record['next'] if sample_record['next'] != '' else None

    total_frames = len(all_sd_tokens)
    num_keyframes = len(keyframe_indices)

    # Require a minimum of 30 frames and 3 keyframes.
    if total_frames < 30 or num_keyframes < 3:
        print(f"Scene {scene_name}: Insufficient frames ({total_frames}) or keyframes ({num_keyframes}), skipping...")
        return 0

    scene_dir = os.path.join(output_dir, 'scenes', f"{scene_name}_{channel}")
    os.makedirs(scene_dir, exist_ok=True)

    video_path = os.path.join(scene_dir, 'full_video.mp4')
    success = render_full_video(nusc, all_sd_tokens, video_path)

    if not success:
        print(f"Failed to render video for {scene_name}")
        return 0

    # Keep only keyframes whose ego pose was actually recorded.
    keyframe_poses = []
    valid_keyframes = []

    for frame_idx in keyframe_indices:
        pose = all_ego_poses[frame_idx]
        if pose is not None:
            keyframe_poses.append(pose)
            valid_keyframes.append(frame_idx)

    # Persist everything needed to reconstruct the clip <-> pose alignment.
    scene_info = {
        'scene_name': scene_name,
        'channel': channel,
        'total_frames': total_frames,
        'keyframe_indices': valid_keyframes,
        'keyframe_poses': keyframe_poses,
        'sample_data_tokens': all_sd_tokens,
        'video_path': 'full_video.mp4'
    }

    with open(os.path.join(scene_dir, 'scene_info.json'), 'w') as f:
        json.dump(scene_info, f, indent=2)

    print(f"Processed scene {scene_name}: {total_frames} frames, {len(valid_keyframes)} keyframes")
    return 1

def render_full_video(nusc, sd_tokens, output_path):
    """Write the image sequence behind `sd_tokens` to an mp4 at 10 fps.

    Returns True on success, False on empty input or any error.
    """
    if not sd_tokens:
        return False

    try:
        # Probe the first frame for the video dimensions.
        first_sd = nusc.get('sample_data', sd_tokens[0])
        first_image_path = os.path.join(nusc.dataroot, first_sd['filename'])
        first_image = Image.open(first_image_path)
        width, height = first_image.size

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, 10.0, (width, height))
        # BUGFIX: release the writer even when an exception fires mid-loop;
        # previously a failure here leaked the VideoWriter handle.
        try:
            for sd_token in sd_tokens:
                sd_record = nusc.get('sample_data', sd_token)
                image_path = os.path.join(nusc.dataroot, sd_record['filename'])

                # Missing/corrupt frames are silently skipped (best-effort render).
                if os.path.exists(image_path):
                    image = cv2.imread(image_path)
                    if image is not None:
                        out.write(image)
        finally:
            out.release()
        return True

    except Exception as e:
        print(f"Error rendering video to {output_path}: {str(e)}")
        return False

def process_scene_dynamic(args):
    """Worker: process one scene across the requested channels.

    Args:
        args: (scene_token, list of channel names) — tuple-packed for
            multiprocessing imap.

    Returns:
        (scene_token, list of successfully processed channels, scenes created).
    """
    scene_token, channels = args
    # Each worker opens its own NuScenes handle (not picklable across processes).
    nusc = NuScenes(version=VERSION, dataroot=DATA_ROOT, verbose=False)
    scene_record = nusc.get('scene', scene_token)
    scene_name = scene_record['name']

    success_channels = []
    total_scenes = 0

    try:
        for channel in channels:
            # Skip channels that already have a scene_info.json on disk.
            scene_dir = os.path.join(OUTPUT_DIR, 'scenes', f"{scene_name}_{channel}")
            if os.path.exists(os.path.join(scene_dir, 'scene_info.json')):
                print(f"Scene {scene_name} {channel} already processed, skipping...")
                success_channels.append(channel)
                continue

            scenes_count = extract_full_scene_with_keyframes(nusc, scene_token, scene_name, OUTPUT_DIR, channel)

            if scenes_count > 0:
                success_channels.append(channel)
                total_scenes += scenes_count
            else:
                print(f"Failed to process scene {scene_name} {channel}")

    except Exception as e:
        # Best-effort: report and return whatever succeeded so far.
        print(f"Error processing {scene_name} ({scene_token}): {str(e)}")

    return scene_token, success_channels, total_scenes

def get_processed_scenes():
    """Read the progress file into {scene_token: set of completed channels}."""
    processed = {}
    if os.path.exists(PROCESSED_SCENES_FILE):
        with open(PROCESSED_SCENES_FILE, 'r') as f:
            for line in f:
                line = line.strip()
                # Each record is "token:chan1,chan2"; skip malformed lines.
                if not line or ':' not in line:
                    continue
                token, channels_str = line.split(':', 1)
                processed[token] = set(channels_str.split(','))
    return processed
get_processed_scenes() + for scene_token, success_chs, scenes_count in results: + if scene_token not in updated: + updated[scene_token] = set() + updated[scene_token].update(success_chs) + total_scenes_created += scenes_count + + # 写入最终记录 + with open(PROCESSED_SCENES_FILE, 'w') as f: + for token, chs in updated.items(): + f.write(f"{token}:{','.join(sorted(chs))}\n") + + print(f"\nProcessing completed!") + print(f"Total scenes created: {total_scenes_created}") + print(f"Output directory: {OUTPUT_DIR}") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/scripts/occupy.py b/scripts/occupy.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f8d3ddb454f6f2d231a3465b7a75d622aae6dc --- /dev/null +++ b/scripts/occupy.py @@ -0,0 +1,106 @@ +import torch +import time +import argparse +from threading import Thread + +def gpu_worker(gpu_id, duration, tensor_size): + """单个GPU的工作线程,负责持续进行张量运算""" + try: + # 设置当前线程使用的GPU + device = torch.device(f"cuda:{gpu_id}") + torch.cuda.set_device(device) + + # 打印GPU信息 + gpu_name = torch.cuda.get_device_name(device) + print(f"GPU {gpu_id} 启动: {gpu_name}") + + # 创建大随机张量 + tensor_a = torch.randn(tensor_size, tensor_size, device=device) + tensor_b = torch.randn(tensor_size, tensor_size, device=device) + + # 预热GPU + for _ in range(10): + result = torch.matmul(tensor_a, tensor_b) + torch.cuda.synchronize(device) + + # 开始持续运算 + start_time = time.time() + iterations = 0 + + while time.time() - start_time < duration: + # 矩阵乘法运算 + result = torch.matmul(tensor_a, tensor_b) + + # 定期更新张量避免优化 + if iterations % 100 == 0: + tensor_a = 0.999 * tensor_a + 0.001 * torch.randn_like(tensor_a) + tensor_b = 0.999 * tensor_b + 0.001 * torch.randn_like(tensor_b) + + iterations += 1 + + # 每10秒打印一次状态 + if iterations % 1000 == 0: + elapsed = time.time() - start_time + print(f"GPU {gpu_id}: 已运行 {elapsed:.1f} 秒, 完成 {iterations} 次迭代") + + # 短暂同步确保计算完成 + if iterations % 100 == 0: + torch.cuda.synchronize(device) + + # 
计算结束统计 + elapsed = time.time() - start_time + print(f"GPU {gpu_id} 完成: 总时间 {elapsed:.1f} 秒, 总迭代 {iterations} 次, " + f"平均每秒 {iterations/elapsed:.2f} 次") + + except Exception as e: + print(f"GPU {gpu_id} 出错: {str(e)}") + + finally: + # 清理内存 + torch.cuda.empty_cache() + +def multi_gpu_stress_test(duration, tensor_size, use_gpus=None): + """多GPU压力测试主函数""" + # 检查可用GPU数量 + available_gpus = torch.cuda.device_count() + if available_gpus == 0: + print("错误: 未检测到可用GPU") + return + + # 确定要使用的GPU + if use_gpus is None: + use_gpus = list(range(available_gpus)) + else: + # 验证GPU ID有效性 + use_gpus = [g for g in use_gpus if 0 <= g < available_gpus] + if not use_gpus: + print("错误: 没有有效的GPU ID") + return + + print(f"检测到 {available_gpus} 张GPU,将使用 {len(use_gpus)} 张: {use_gpus}") + + # 为每张GPU创建并启动线程 + threads = [] + for gpu_id in use_gpus: + thread = Thread(target=gpu_worker, args=(gpu_id, duration, tensor_size)) + threads.append(thread) + thread.start() + + # 等待所有线程完成 + for thread in threads: + thread.join() + + print("所有GPU测试完成") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='多GPU压力测试程序') + parser.add_argument('--duration', type=int, default=6000000, + help='测试持续时间(秒),默认60秒') + parser.add_argument('--size', type=int, default=4096, + help='每张GPU上的张量大小,默认4096x4096') + parser.add_argument('--gpus', type=int, nargs='+', + help=f'指定要使用的GPU ID,如 --gpus 0 1 2 3 4 5 6 7') + args = parser.parse_args() + + # 运行多GPU测试 + multi_gpu_stress_test(args.duration, args.size, args.gpus) diff --git a/scripts/pose_classifier.py b/scripts/pose_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..2b590463d870a8bf6fd97efc10f60018b890b2f0 --- /dev/null +++ b/scripts/pose_classifier.py @@ -0,0 +1,247 @@ +import torch +import numpy as np +from typing import List, Tuple + +class PoseClassifier: + """将pose参数分类为前后左右四个类别,正确使用rotation数据判断转弯""" + + def __init__(self): + # 定义四个方向的类别 + self.FORWARD = 0 + self.BACKWARD = 1 + self.LEFT_TURN = 2 + self.RIGHT_TURN 
= 3 + + self.class_names = ['forward', 'backward', 'left_turn', 'right_turn'] + + def classify_pose_sequence(self, pose_sequence: torch.Tensor) -> torch.Tensor: + """ + 对pose序列进行分类,基于相对于reference的pose变化 + Args: + pose_sequence: [num_frames, 7] (relative_translation + relative_quaternion) + 这里的pose都是相对于reference帧的相对变换 + Returns: + classifications: [num_frames] 类别标签 + """ + # 提取平移部分 [num_frames, 3] 和旋转部分 [num_frames, 4] + translations = pose_sequence[:, :3] # 相对于reference的位移 + rotations = pose_sequence[:, 3:7] # 相对于reference的旋转 [w, x, y, z] + + # 分类每一帧 - 都是相对于reference帧的变化 + classifications = [] + for i in range(len(pose_sequence)): + # 🔧 修改:每一帧都基于相对于reference的变化进行分类 + relative_translation = translations[i] # 相对于reference的位移 + relative_rotation = rotations[i] # 相对于reference的旋转 + + class_label = self._classify_single_pose(relative_translation, relative_rotation) + classifications.append(class_label) + + return torch.tensor(classifications, dtype=torch.long) + + def _classify_single_pose(self, relative_translation: torch.Tensor, + relative_rotation: torch.Tensor) -> int: + """ + 对单个pose进行分类,基于相对于reference的变化 + Args: + relative_translation: [3] 相对于reference的位移变化 + relative_rotation: [4] 相对于reference的旋转四元数 [w, x, y, z] + """ + # 🔧 关键:从相对旋转四元数提取yaw角度 + yaw_angle = self._quaternion_to_yaw(relative_rotation) + + # 🔧 计算前进/后退(主要看x方向的位移) + forward_movement = -relative_translation[0].item() # x负方向为前进 + + # 🔧 设置阈值 + yaw_threshold = 0.05 # 约2.9度,可以调整 + movement_threshold = 0.01 # 位移阈值 + + # 🔧 优先判断转弯(基于相对于reference的yaw角度) + if abs(yaw_angle) > yaw_threshold: + if yaw_angle > 0: + return self.LEFT_TURN # 正yaw角度为左转 + else: + return self.RIGHT_TURN # 负yaw角度为右转 + + # 🔧 如果没有明显转弯,判断前进后退(基于相对位移) + if abs(forward_movement) > movement_threshold: + if forward_movement > 0: + return self.FORWARD + else: + return self.BACKWARD + + # 🔧 如果位移和旋转都很小,判断为前进(静止时的默认状态) + return self.FORWARD + + def _quaternion_to_yaw(self, q: torch.Tensor) -> float: + """ + 从四元数提取yaw角度(绕z轴旋转) + Args: + q: [4] 四元数 [w, 
x, y, z] + Returns: + yaw: yaw角度(弧度) + """ + try: + # 转换为numpy数组进行计算 + q_np = q.detach().cpu().numpy() + + # 🔧 确保四元数是单位四元数 + norm = np.linalg.norm(q_np) + if norm > 1e-8: + q_np = q_np / norm + else: + # 如果四元数接近零,返回0角度 + return 0.0 + + w, x, y, z = q_np + + # 🔧 计算yaw角度:atan2(2*(w*z + x*y), 1 - 2*(y^2 + z^2)) + yaw = np.arctan2(2.0 * (w*z + x*y), 1.0 - 2.0 * (y*y + z*z)) + + return float(yaw) + + except Exception as e: + print(f"Error computing yaw from quaternion: {e}") + return 0.0 + + def create_class_embedding(self, class_labels: torch.Tensor, embed_dim: int = 512) -> torch.Tensor: + """ + 为类别标签创建embedding + Args: + class_labels: [num_frames] 类别标签 + embed_dim: embedding维度 + Returns: + embeddings: [num_frames, embed_dim] + """ + num_classes = 4 + num_frames = len(class_labels) + + # 🔧 创建更有意义的embedding,不同类别有不同的特征 + # 使用预定义的方向向量 + direction_vectors = torch.tensor([ + [1.0, 0.0, 0.0, 0.0], # forward: 主要x分量 + [-1.0, 0.0, 0.0, 0.0], # backward: 负x分量 + [0.0, 1.0, 0.0, 0.0], # left_turn: 主要y分量 + [0.0, -1.0, 0.0, 0.0], # right_turn: 负y分量 + ], dtype=torch.float32) + + # One-hot编码 + one_hot = torch.zeros(num_frames, num_classes) + one_hot.scatter_(1, class_labels.unsqueeze(1), 1) + + # 基于方向向量的基础embedding + base_embeddings = one_hot @ direction_vectors # [num_frames, 4] + + # 扩展到目标维度 + if embed_dim > 4: + # 使用线性变换扩展 + expand_matrix = torch.randn(4, embed_dim) * 0.1 + # 保持方向性 + expand_matrix[:4, :4] = torch.eye(4) + embeddings = base_embeddings @ expand_matrix + else: + embeddings = base_embeddings[:, :embed_dim] + + return embeddings + + def get_class_name(self, class_id: int) -> str: + """获取类别名称""" + return self.class_names[class_id] + + def analyze_pose_sequence(self, pose_sequence: torch.Tensor) -> dict: + """ + 分析pose序列,返回详细的统计信息 + Args: + pose_sequence: [num_frames, 7] (translation + quaternion) + Returns: + analysis: 包含统计信息的字典 + """ + classifications = self.classify_pose_sequence(pose_sequence) + + # 统计各类别数量 + class_counts = torch.bincount(classifications, 
minlength=4) + + # 计算连续运动段 + motion_segments = [] + if len(classifications) > 0: + current_class = classifications[0].item() + segment_start = 0 + + for i in range(1, len(classifications)): + if classifications[i].item() != current_class: + motion_segments.append({ + 'class': self.get_class_name(current_class), + 'start_frame': segment_start, + 'end_frame': i-1, + 'duration': i - segment_start + }) + current_class = classifications[i].item() + segment_start = i + + # 添加最后一个段 + motion_segments.append({ + 'class': self.get_class_name(current_class), + 'start_frame': segment_start, + 'end_frame': len(classifications)-1, + 'duration': len(classifications) - segment_start + }) + + # 计算总体运动信息 + translations = pose_sequence[:, :3] + if len(translations) > 1: + # 计算累积距离(相对于reference的总移动距离) + total_distance = torch.norm(translations[-1] - translations[0]) + else: + total_distance = torch.tensor(0.0) + + analysis = { + 'total_frames': len(pose_sequence), + 'class_distribution': { + self.get_class_name(i): count.item() + for i, count in enumerate(class_counts) + }, + 'motion_segments': motion_segments, + 'total_distance': total_distance.item(), + 'classifications': classifications + } + + return analysis + + def debug_single_pose(self, relative_translation: torch.Tensor, + relative_rotation: torch.Tensor) -> dict: + """ + 调试单个pose的分类过程 + Args: + relative_translation: [3] 相对位移 + relative_rotation: [4] 相对旋转四元数 + Returns: + debug_info: 调试信息字典 + """ + yaw_angle = self._quaternion_to_yaw(relative_rotation) + forward_movement = -relative_translation[0].item() + + yaw_threshold = 0.05 + movement_threshold = 0.01 + + classification = self._classify_single_pose(relative_translation, relative_rotation) + + debug_info = { + 'relative_translation': relative_translation.tolist(), + 'relative_rotation': relative_rotation.tolist(), + 'yaw_angle_rad': yaw_angle, + 'yaw_angle_deg': np.degrees(yaw_angle), + 'forward_movement': forward_movement, + 'yaw_threshold': yaw_threshold, + 
'movement_threshold': movement_threshold, + 'classification': self.get_class_name(classification), + 'classification_id': classification, + 'decision_process': { + 'abs_yaw_exceeds_threshold': abs(yaw_angle) > yaw_threshold, + 'abs_movement_exceeds_threshold': abs(forward_movement) > movement_threshold, + 'yaw_direction': 'left' if yaw_angle > 0 else 'right' if yaw_angle < 0 else 'none', + 'movement_direction': 'forward' if forward_movement > 0 else 'backward' if forward_movement < 0 else 'none' + } + } + + return debug_info \ No newline at end of file diff --git a/scripts/rebuttal.md b/scripts/rebuttal.md new file mode 100644 index 0000000000000000000000000000000000000000..f5d8d986c2a40bf42a006b201617ae5564060792 --- /dev/null +++ b/scripts/rebuttal.md @@ -0,0 +1,6 @@ +We sincerely thank all reviewers for their constructive feedback and valuable suggestions. We are encouraged by the positive reception of our work, with all reviewers finding it well-written. Reviewer ZCNg described our approach as “innovative”, Reviewer mTLA acknowledged that it “effectively improves the performance”, and Reviewer LCnj highlighted its value for “various distillation studies”. We have carefully addressed the raised concerns and clarified potential confusions by incorporating corresponding modifications into our paper (with revisions highlighted in blue). 
+ +**RnKJ**: + - **Inconsistency in the core motivation and training design.** + + - **Mismatch between metrics and visual realism.** \ No newline at end of file diff --git a/scripts/setup.py b/scripts/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..d2a429f7cb470590ca1b1ccf02ae561e1929762b --- /dev/null +++ b/scripts/setup.py @@ -0,0 +1,30 @@ +import os +from setuptools import setup, find_packages +import pkg_resources + +# Path to the requirements file +requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt") + +# Read the requirements from the requirements file +if os.path.exists(requirements_path): + with open(requirements_path, 'r') as f: + install_requires = [str(r) for r in pkg_resources.parse_requirements(f)] +else: + install_requires = [] + +setup( + name="diffsynth", + version="1.1.2", + description="Enjoy the magic of Diffusion models!", + author="Artiprocher", + packages=find_packages(), + install_requires=install_requires, + include_package_data=True, + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + ], + package_data={"diffsynth": ["tokenizer_configs/**/**/*.*"]}, + python_requires='>=3.6', +) diff --git a/scripts/spatialvid_pose_test.py b/scripts/spatialvid_pose_test.py new file mode 100644 index 0000000000000000000000000000000000000000..973725b88d9b9c0ef74f94d179a7732fcedf38f4 --- /dev/null +++ b/scripts/spatialvid_pose_test.py @@ -0,0 +1,159 @@ + +import torch + +import os + + +import os +import json +import torch +import numpy as np + +from scipy.spatial.transform import Rotation as R + +import pdb + +def compute_relative_pose_matrix2( + pose_a, + pose_b +) -> np.ndarray: + """ + 计算两个相机姿态(7元数组)之间的相对位姿,输出3×4相机矩阵 [R_rel | t_rel] + + 数学定义:相对位姿描述“从姿态A到姿态B的变换”,即: + - 若点P在姿态A的相机坐标系下坐标为P_A,在姿态B的相机坐标系下坐标为P_B, + 则满足 P_B = R_rel @ P_A + t_rel(R_rel为相对旋转矩阵,t_rel为相对平移向量) + + 参数: + pose_a: 
参考姿态A,形状(7,)的数组/list,格式[tx_a, ty_a, tz_a, qx_a, qy_a, qz_a, qw_a] + - tx_a/ty_a/tz_a: 姿态A在世界坐标系的位置(平移向量) + - qx_a/qy_a/qz_a/qw_a: 姿态A的朝向(单位四元数,右手坐标系) + pose_b: 目标姿态B,格式与pose_a完全一致 + + 返回: + 3×4的相对位姿相机矩阵,前3列是3×3相对旋转矩阵R_rel,第4列是3×1相对平移向量t_rel + + 异常: + ValueError: 若输入姿态形状/格式不正确,或四元数非单位四元数 + """ + # -------------------------- + # 1. 输入校验(确保数据格式正确) + # -------------------------- + # 转换为numpy数组并检查形状 + pose_a = np.asarray(pose_a, dtype=np.float64) + pose_b = np.asarray(pose_b, dtype=np.float64) + + if pose_a.shape != (7,): + raise ValueError(f"姿态A需为(7,)数组,实际输入形状{pose_a.shape}") + if pose_b.shape != (7,): + raise ValueError(f"姿态B需为(7,)数组,实际输入形状{pose_b.shape}") + + # 分离平移向量和四元数 + t_a = pose_a[:3] # 姿态A的世界坐标:[tx_a, ty_a, tz_a] + q_a = pose_a[3:] # 姿态A的四元数:[qx_a, qy_a, qz_a, qw_a] + t_b = pose_b[:3] # 姿态B的世界坐标 + q_b = pose_b[3:] # 姿态B的四元数 + + # 检查四元数是否为单位四元数(避免旋转计算错误) + q_a_norm = np.linalg.norm(q_a) + q_b_norm = np.linalg.norm(q_b) + if not np.isclose(q_a_norm, 1.0, atol=1e-4): + raise ValueError(f"姿态A的四元数非单位四元数,模长为{q_a_norm:.6f}(需接近1.0)") + if not np.isclose(q_b_norm, 1.0, atol=1e-4): + raise ValueError(f"姿态B的四元数非单位四元数,模长为{q_b_norm:.6f}(需接近1.0)") + + # -------------------------- + # 2. 计算相对旋转矩阵 R_rel + # -------------------------- + # 将四元数转换为Rotation对象(scipy自动处理右手坐标系) + rot_a = R.from_quat(q_a) # 姿态A的旋转矩阵(世界→A相机的旋转) + rot_b = R.from_quat(q_b) # 姿态B的旋转矩阵(世界→B相机的旋转) + + # 相对旋转 = 姿态B的旋转 × 姿态A旋转的逆(单位旋转矩阵的逆=转置) + # 数学逻辑:R_rel 描述“A相机坐标系→B相机坐标系”的旋转 + rot_rel = rot_b * rot_a.inv() + R_rel = rot_rel.as_matrix() # 转换为3×3矩阵, dtype=np.float64 + + # -------------------------- + # 3. 计算相对平移向量 t_rel + # -------------------------- + # 数学推导:t_rel = R_rel @ (-rot_a.inv() @ t_a) + (rot_b.inv() @ t_b) + # 简化后:t_rel = rot_a.inv().as_matrix().T @ (t_b - t_a) + # 物理意义:在A相机坐标系下,B相机相对于A相机的位置 + R_a_T = rot_a.inv().as_matrix().T # 姿态A旋转矩阵的逆=转置(单位矩阵性质) + t_rel = R_a_T @ (t_b - t_a) # 3×1相对平移向量 + + # -------------------------- + # 4. 
组合为3×4相机矩阵 + # -------------------------- + # 拼接旋转矩阵(3×3)和平移向量(3×1),形成3×4矩阵 + relative_cam_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)]) + + return relative_cam_matrix + + +def compute_relative_pose_matrix(pose1, pose2): + """ + 计算相邻两帧的相对位姿,返回3×4的相机矩阵 [R_rel | t_rel] + + 参数: + pose1: 第i帧的相机位姿,形状为(7,)的数组 [tx1, ty1, tz1, qx1, qy1, qz1, qw1] + pose2: 第i+1帧的相机位姿,形状为(7,)的数组 [tx2, ty2, tz2, qx2, qy2, qz2, qw2] + + 返回: + relative_matrix: 3×4的相对位姿矩阵,前3列是旋转矩阵R_rel,第4列是平移向量t_rel + """ + # 分离平移向量和四元数 + t1 = pose1[:3] # 第i帧平移 [tx1, ty1, tz1] + q1 = pose1[3:] # 第i帧四元数 [qx1, qy1, qz1, qw1] + t2 = pose2[:3] # 第i+1帧平移 + q2 = pose2[3:] # 第i+1帧四元数 + + # 1. 计算相对旋转矩阵 R_rel + rot1 = R.from_quat(q1) # 第i帧旋转 + rot2 = R.from_quat(q2) # 第i+1帧旋转 + rot_rel = rot2 * rot1.inv() # 相对旋转 = 后一帧旋转 × 前一帧旋转的逆 + R_rel = rot_rel.as_matrix() # 转换为3×3矩阵 + + # 2. 计算相对平移向量 t_rel + R1_T = rot1.as_matrix().T # 前一帧旋转矩阵的转置(等价于逆) + t_rel = R1_T @ (t2 - t1) # 相对平移 = R1^T × (t2 - t1) + + # 3. 组合为3×4矩阵 [R_rel | t_rel] + relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)]) + + return relative_matrix + +encoded_data = torch.load( + os.path.join('/share_zhuyixuan05/zhuyixuan05/spatialvid/fdb39216-0d15-5f0f-a78f-c599913a4a2e_0000600_0000900', "encoded_video.pth"), + weights_only=False, + map_location="cpu" + ) + +cam_data_ori = np.load('./poses.npy') + +cam_data_seq_ori = cam_data_ori +print(cam_data_seq_ori.shape) +print('---------------------------') +cam_data = encoded_data['cam_emb'] + +cam_data_seq = cam_data_seq_ori # +cam_data_seq_inter = cam_data['extrinsic'] +print(cam_data_seq_inter.shape) +keyframe_original_idx = list(range(10)) + +relative_cams = [] + +for idx in keyframe_original_idx: + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx+1] + + relative_cam = compute_relative_pose_matrix2(cam_prev,cam_next) + + relative_cams.append(torch.as_tensor(relative_cam[:3,:])) +relative_cam = compute_relative_pose_matrix2(cam_data_seq_inter[0],cam_data_seq_inter[-1]) + 
+relative_cams.append(torch.as_tensor(relative_cam[:3,:])) + +print(relative_cams[-1]) + diff --git a/scripts/test_data.py b/scripts/test_data.py new file mode 100644 index 0000000000000000000000000000000000000000..9fed39601f76f3d6b6323d541b07b5303d4eb36c --- /dev/null +++ b/scripts/test_data.py @@ -0,0 +1,25 @@ +import pickle +import matplotlib.pyplot as plt +from PIL import Image +import io +import numpy as np +with open('sample_000000000000.data.pickle', 'rb') as f: + data = pickle.load(f) + +def imbytes2arr(b): + return np.array(Image.open(io.BytesIO(b))) + +step = data['steps'][0] +print("Instruction:", step['observation']['natural_language_instruction'].decode()) + +fig, axs = plt.subplots(1, 3, figsize=(12, 4)) +titles = ['image', 'hand_image', 'image_with_depth'] +keys = ['image', 'hand_image', 'image_with_depth'] +for ax, t, k in zip(axs, titles, keys): + img = imbytes2arr(step['observation'][k]) + ax.imshow(img) + ax.set_title(t) + ax.axis('off') +plt.tight_layout() +plt.savefig('step0_views.png', dpi=120) # 保存到文件 +print('Saved -> step0_views.png') \ No newline at end of file diff --git a/scripts/test_moe.py b/scripts/test_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..a65adfa3b7e0578a81f2ac084336cde0c41f1787 --- /dev/null +++ b/scripts/test_moe.py @@ -0,0 +1,680 @@ +import torch +import torch.nn as nn +import lightning as pl +import wandb +import os +import copy +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import json +import numpy as np +from PIL import Image +import imageio +import random +from torchvision.transforms import v2 +from einops import rearrange +from pose_classifier import PoseClassifier +from scipy.spatial.transform import Rotation as R +import traceback +import argparse + +def compute_relative_pose_matrix(pose1, pose2): + """ + 计算相邻两帧的相对位姿,返回3×4的相机矩阵 [R_rel | t_rel] + + 参数: + pose1: 第i帧的相机位姿,形状为(7,)的数组 [tx1, ty1, tz1, qx1, qy1, qz1, qw1] + pose2: 第i+1帧的相机位姿,形状为(7,)的数组 [tx2, ty2, tz2, qx2, 
qy2, qz2, qw2] + + 返回: + relative_matrix: 3×4的相对位姿矩阵,前3列是旋转矩阵R_rel,第4列是平移向量t_rel + """ + # 分离平移向量和四元数 + t1 = pose1[:3] # 第i帧平移 [tx1, ty1, tz1] + q1 = pose1[3:] # 第i帧四元数 [qx1, qy1, qz1, qw1] + t2 = pose2[:3] # 第i+1帧平移 + q2 = pose2[3:] # 第i+1帧四元数 + + # 1. 计算相对旋转矩阵 R_rel + rot1 = R.from_quat(q1) # 第i帧旋转 + rot2 = R.from_quat(q2) # 第i+1帧旋转 + rot_rel = rot2 * rot1.inv() # 相对旋转 = 后一帧旋转 × 前一帧旋转的逆 + R_rel = rot_rel.as_matrix() # 转换为3×3矩阵 + + # 2. 计算相对平移向量 t_rel + R1_T = rot1.as_matrix().T # 前一帧旋转矩阵的转置(等价于逆) + t_rel = R1_T @ (t2 - t1) # 相对平移 = R1^T × (t2 - t1) + + # 3. 组合为3×4矩阵 [R_rel | t_rel] + relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)]) + + return relative_matrix + + +class SpatialVidFramePackDataset(torch.utils.data.Dataset): + """支持FramePack机制的SpatialVid数据集""" + + def __init__(self, base_path, steps_per_epoch, + min_condition_frames=10, max_condition_frames=40, + target_frames=10, height=900, width=1600): + self.base_path = base_path + self.scenes_path = base_path + self.min_condition_frames = min_condition_frames + self.max_condition_frames = max_condition_frames + self.target_frames = target_frames + self.height = height + self.width = width + self.steps_per_epoch = steps_per_epoch + self.pose_classifier = PoseClassifier() + + # VAE时间压缩比例 + self.time_compression_ratio = 4 # VAE将时间维度压缩4倍 + + # 查找所有处理好的场景 + self.scene_dirs = [] + if os.path.exists(self.scenes_path): + for item in os.listdir(self.scenes_path): + scene_dir = os.path.join(self.scenes_path, item) + if os.path.isdir(scene_dir): + encoded_path = os.path.join(scene_dir, "encoded_video.pth") + if os.path.exists(encoded_path): + self.scene_dirs.append(scene_dir) + + print(f"🔧 找到 {len(self.scene_dirs)} 个SpatialVid场景") + assert len(self.scene_dirs) > 0, "No encoded scenes found!" 
+ + def select_dynamic_segment_framepack(self, full_latents): + """🔧 FramePack风格的动态选择条件帧和目标帧 - SpatialVid版本""" + total_lens = full_latents.shape[1] + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + max_condition_compressed = min(max_condition_compressed, total_lens - target_frames_compressed) + + ratio = random.random() + #print('ratio:', ratio) + if ratio < 0.15: + condition_frames_compressed = 1 + elif 0.15 <= ratio < 0.9: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if total_lens < min_required_frames: + print(f"压缩后帧数不足: {total_lens} < {min_required_frames}") + return None + + # 随机选择起始位置(基于压缩后的帧数) + max_start = total_lens - min_required_frames - 1 + start_frame_compressed = random.randint(0, max_start) + + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + # 🔧 FramePack风格的索引处理 + latent_indices = torch.arange(condition_end_compressed, target_end_compressed) # 只预测未来帧 + + # 🔧 根据实际的condition_frames_compressed生成索引 + # 1x帧:起始帧 + 最后1帧 + clean_latent_indices_start = torch.tensor([start_frame_compressed]) + clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1]) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices]) + + # 🔧 2x帧:根据实际condition长度确定 + if condition_frames_compressed >= 2: + # 取最后2帧(如果有的话) + clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2) + clean_latent_2x_indices = torch.arange(clean_latent_2x_start-1, condition_end_compressed-1) + else: + # 
如果condition帧数不足2帧,创建空索引 + clean_latent_2x_indices = torch.tensor([], dtype=torch.long) + + # 🔧 4x帧:根据实际condition长度确定,最多16帧 + if condition_frames_compressed >= 1: + # 取最多16帧的历史(如果有的话) + clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16) + clean_latent_4x_indices = torch.arange(clean_4x_start-3, condition_end_compressed-3) + else: + clean_latent_4x_indices = torch.tensor([], dtype=torch.long) + + # 对应的原始关键帧索引 - SpatialVid特有:每隔1帧而不是4帧 + keyframe_original_idx = [] + for compressed_idx in range(start_frame_compressed, target_end_compressed): + keyframe_original_idx.append(compressed_idx) # SpatialVid使用1倍间隔 + + return { + 'start_frame': start_frame_compressed, + 'condition_frames': condition_frames_compressed, + 'target_frames': target_frames_compressed, + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + + # FramePack风格的索引 + 'latent_indices': latent_indices, + 'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + + 'keyframe_original_idx': keyframe_original_idx, + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + + def create_pose_embeddings(self, cam_data, segment_info): + """🔧 创建SpatialVid风格的pose embeddings - camera间隔为1帧而非4帧""" + cam_data_seq = cam_data['extrinsic'] # N * 4 * 4 + + # 🔧 为所有帧(condition + target)计算camera embedding + # SpatialVid特有:每隔1帧而不是4帧 + keyframe_original_idx = segment_info['keyframe_original_idx'] + + relative_cams = [] + for idx in keyframe_original_idx: + if idx + 1 < len(cam_data_seq): + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx + 1] # SpatialVid: 每隔1帧 + relative_cam = compute_relative_pose_matrix(cam_prev, cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + else: + # 
如果没有下一帧,使用零运动 + identity_cam = torch.zeros(3, 4) + relative_cams.append(identity_cam) + + if len(relative_cams) == 0: + return None + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def prepare_framepack_inputs(self, full_latents, segment_info): + """🔧 准备FramePack风格的多尺度输入 - SpatialVid版本""" + # 🔧 修正:处理4维输入 [C, T, H, W],添加batch维度 + if len(full_latents.shape) == 4: + full_latents = full_latents.unsqueeze(0) # [C, T, H, W] -> [1, C, T, H, W] + B, C, T, H, W = full_latents.shape + else: + B, C, T, H, W = full_latents.shape + + # 主要latents(用于去噪预测) + latent_indices = segment_info['latent_indices'] + main_latents = full_latents[:, :, latent_indices, :, :] + + # 🔧 1x条件帧(起始帧 + 最后1帧) + clean_latent_indices = segment_info['clean_latent_indices'] + clean_latents = full_latents[:, :, clean_latent_indices, :, :] + + # 🔧 4x条件帧 - 总是16帧,直接用真实索引 + 0填充 + clean_latent_4x_indices = segment_info['clean_latent_4x_indices'] + + # 创建固定长度16的latents,初始化为0 + clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype) + clean_latent_4x_indices_final = torch.full((16,), -1, dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的4x索引 + if len(clean_latent_4x_indices) > 0: + actual_4x_frames = len(clean_latent_4x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 16 - actual_4x_frames) + end_pos = 16 + actual_start = max(0, actual_4x_frames - 16) # 如果超过16帧,只取最后16帧 + + clean_latents_4x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :] + clean_latent_4x_indices_final[start_pos:end_pos] = clean_latent_4x_indices[actual_start:] + + # 🔧 2x条件帧 - 总是2帧,直接用真实索引 + 0填充 + clean_latent_2x_indices = segment_info['clean_latent_2x_indices'] + + # 创建固定长度2的latents,初始化为0 + clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype) + clean_latent_2x_indices_final = torch.full((2,), -1, 
dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的2x索引 + if len(clean_latent_2x_indices) > 0: + actual_2x_frames = len(clean_latent_2x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 2 - actual_2x_frames) + end_pos = 2 + actual_start = max(0, actual_2x_frames - 2) # 如果超过2帧,只取最后2帧 + + clean_latents_2x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_2x_indices[actual_start:], :, :] + clean_latent_2x_indices_final[start_pos:end_pos] = clean_latent_2x_indices[actual_start:] + + # 🔧 移除添加的batch维度,返回原始格式 + if B == 1: + main_latents = main_latents.squeeze(0) # [1, C, T, H, W] -> [C, T, H, W] + clean_latents = clean_latents.squeeze(0) + clean_latents_2x = clean_latents_2x.squeeze(0) + clean_latents_4x = clean_latents_4x.squeeze(0) + + return { + 'latents': main_latents, + 'clean_latents': clean_latents, + 'clean_latents_2x': clean_latents_2x, + 'clean_latents_4x': clean_latents_4x, + 'latent_indices': segment_info['latent_indices'], + 'clean_latent_indices': segment_info['clean_latent_indices'], + 'clean_latent_2x_indices': clean_latent_2x_indices_final, + 'clean_latent_4x_indices': clean_latent_4x_indices_final, + } + + def __getitem__(self, index): + while True: + try: + # 随机选择一个场景 + scene_dir = random.choice(self.scene_dirs) + + # 加载编码的视频数据 + encoded_data = torch.load( + os.path.join(scene_dir, "encoded_video.pth"), + weights_only=False, + map_location="cpu" + ) + + # 🔧 验证latent帧数是否符合预期 + full_latents = encoded_data['latents'] # [C, T, H, W] + cam_data = encoded_data['cam_emb'] + actual_latent_frames = full_latents.shape[1] + + # 动态选择段落 + segment_info = self.select_dynamic_segment_framepack(full_latents) + if segment_info is None: + continue + + # 创建pose embeddings - SpatialVid版本 + all_camera_embeddings = self.create_pose_embeddings(cam_data, segment_info) + if all_camera_embeddings is None: + continue + + # 🔧 准备FramePack风格的多尺度输入 + framepack_inputs = self.prepare_framepack_inputs(full_latents, segment_info) + + n = segment_info["condition_frames"] + 
m = segment_info['target_frames'] + + # 🔧 处理camera embedding with mask + mask = torch.zeros(n+m, dtype=torch.float32) + mask[:n] = 1.0 # condition帧标记为1 + mask = mask.view(-1, 1) + + # 添加mask到camera embeddings + camera_with_mask = torch.cat([all_camera_embeddings, mask], dim=1) + + result = { + # 🔧 FramePack风格的多尺度输入 + "latents": framepack_inputs['latents'], # 主要预测目标 + "clean_latents": framepack_inputs['clean_latents'], # 条件帧 + "clean_latents_2x": framepack_inputs['clean_latents_2x'], + "clean_latents_4x": framepack_inputs['clean_latents_4x'], + "latent_indices": framepack_inputs['latent_indices'], + "clean_latent_indices": framepack_inputs['clean_latent_indices'], + "clean_latent_2x_indices": framepack_inputs['clean_latent_2x_indices'], + "clean_latent_4x_indices": framepack_inputs['clean_latent_4x_indices'], + + # 🔧 直接传递带mask的camera embeddings + "camera": camera_with_mask, # 所有帧的camera embeddings(带mask) + + "prompt_emb": encoded_data["prompt_emb"], + "image_emb": encoded_data.get("image_emb", {}), + + "condition_frames": n, # 压缩后的帧数 + "target_frames": m, # 压缩后的帧数 + "scene_name": os.path.basename(scene_dir), + "dataset_name": "spatialvid", + # 🔧 新增:记录原始帧数用于调试 + "original_condition_frames": segment_info['original_condition_frames'], + "original_target_frames": segment_info['original_target_frames'], + } + + return result + + except Exception as e: + print(f"Error loading sample: {e}") + traceback.print_exc() + continue + + def __len__(self): + return self.steps_per_epoch + + +def replace_dit_model_in_manager(): + """在模型加载前替换DiT模型类""" + from diffsynth.models.wan_video_dit_moe import WanModelMoe + from diffsynth.configs.model_config import model_loader_configs + + # 修改model_loader_configs中的配置 + for i, config in enumerate(model_loader_configs): + keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config + + # 检查是否包含wan_video_dit模型 + if 'wan_video_dit' in model_names: + # 找到wan_video_dit的索引并替换为WanModelFuture + new_model_names = [] + 
def replace_dit_model_in_manager():
    """Patch diffsynth's loader registry so 'wan_video_dit' entries build WanModelMoe.

    Must be called before ModelManager.load_models(); mutates the shared
    model_loader_configs list in place.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for i, config in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config
        if 'wan_video_dit' in model_names:
            new_model_names = []
            new_model_classes = []
            for name, cls in zip(model_names, model_classes):
                if name == 'wan_video_dit':
                    new_model_names.append(name)  # keep the registry name unchanged
                    new_model_classes.append(WanModelMoe)  # swap in the MoE class
                    print(f"✅ 替换了模型类: {name} -> WanModelMoe")
                else:
                    new_model_names.append(name)
                    new_model_classes.append(cls)
            model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource)


class SpatialVidFramePackLightningModel(pl.LightningModule):
    """Trains a Wan DiT with FramePack multi-scale conditioning + MoE camera control."""

    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        resume_ckpt_path=None
    ):
        super().__init__()
        replace_dit_model_in_manager()  # must run before any model is loaded
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        if os.path.isfile(dit_path):
            model_manager.load_models([dit_path])
        else:
            # NOTE(review): this wraps the split list in another list
            # (load_models([list])) — mirrors the sibling script, but verify
            # ModelManager accepts a nested list for sharded checkpoints.
            dit_path = dit_path.split(",")
            model_manager.load_models([dit_path])
        model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"])

        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)

        # Attach FramePack's clean_x_embedder and the per-block MoE components.
        self.add_framepack_components()
        self.add_moe_components()

        # Per-block camera encoder: 13-dim pose (12 pose + 1 mask) -> model dim.
        # projector starts as identity so training begins from a no-op.
        dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
        for block in self.pipe.dit.blocks:
            block.cam_encoder = nn.Linear(13, dim)
            block.projector = nn.Linear(dim, dim)
            block.cam_encoder.weight.data.zero_()
            block.cam_encoder.bias.data.zero_()
            block.projector.weight = nn.Parameter(torch.eye(dim))
            block.projector.bias = nn.Parameter(torch.zeros(dim))

        if resume_ckpt_path is not None:
            # strict=False: the checkpoint may predate the MoE/FramePack additions.
            state_dict = torch.load(resume_ckpt_path, map_location="cpu")
            self.pipe.dit.load_state_dict(state_dict, strict=False)
            print('load checkpoint:', resume_ckpt_path)

        self.freeze_parameters()

        # Only MoE and sekai-processor modules are trainable.
        for name, module in self.pipe.denoising_model().named_modules():
            if any(keyword in name for keyword in ["moe", "sekai_processor"]):
                for param in module.parameters():
                    param.requires_grad = True

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

        self.vis_dir = "spatialvid_framepack/visualizations"
        os.makedirs(self.vis_dir, exist_ok=True)

    def add_framepack_components(self):
        """Attach FramePack's multi-scale clean_x_embedder to the DiT (idempotent)."""
        if not hasattr(self.pipe.dit, 'clean_x_embedder'):
            inner_dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]

            class CleanXEmbedder(nn.Module):
                def __init__(self, inner_dim):
                    super().__init__()
                    # Mirrors the patchify design in hunyuan_video_packed.py:
                    # one Conv3d per temporal/spatial compression level.
                    self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
                    self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
                    self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

                def forward(self, x, scale="1x"):
                    if scale == "1x":
                        return self.proj(x)
                    elif scale == "2x":
                        return self.proj_2x(x)
                    elif scale == "4x":
                        return self.proj_4x(x)
                    else:
                        raise ValueError(f"Unsupported scale: {scale}")

            self.pipe.dit.clean_x_embedder = CleanXEmbedder(inner_dim)
            print("✅ 添加了FramePack的clean_x_embedder组件")

    def add_moe_components(self):
        """Attach per-block modality processors and MoE networks to the DiT."""
        # FIX: the original read `self.moe_config` unconditionally, but no code
        # path ever assigns it, so __init__ crashed with AttributeError.
        # Guard with getattr so the config is forwarded only when present.
        moe_config = getattr(self, "moe_config", None)
        if moe_config is not None and not hasattr(self.pipe.dit, 'moe_config'):
            self.pipe.dit.moe_config = moe_config
            print("✅ 添加了MoE配置到模型")

        from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE  # hoisted out of the loop

        dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
        unified_dim = 25

        for block in self.pipe.dit.blocks:
            # Sekai modality processor: 13-dim input -> unified_dim.
            block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
            # block.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
            # OpenX also uses 13-dim input but is processed independently.
            block.openx_processor = ModalityProcessor("openx", 13, unified_dim)

            # MoE network: unified_dim in, transformer-block dim out.
            block.moe = MultiModalMoE(
                unified_dim=unified_dim,
                output_dim=dim,
                num_experts=1,
                top_k=1
            )

    def freeze_parameters(self):
        """Freeze the whole pipeline; keep only the denoising model in train mode."""
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()

    def training_step(self, batch, batch_idx):
        """FramePack-style training step for SpatialVid (batch_size is 1)."""
        # Debug/bookkeeping values carried in the batch.
        condition_frames = batch["condition_frames"][0].item()
        target_frames = batch["target_frames"][0].item()
        original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0]
        original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0]
        dataset_name = batch.get("dataset_name", ["unknown"])[0]
        scene_name = batch.get("scene_name", ["unknown"])[0]

        # Main prediction target; ensure a batch dimension.
        latents = batch["latents"].to(self.device)
        if len(latents.shape) == 4:  # [C, T, H, W]
            latents = latents.unsqueeze(0)  # -> [1, C, T, H, W]

        # Conditioning latents at each scale; empty tensors mean "absent".
        clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None
        if clean_latents is not None and len(clean_latents.shape) == 4:
            clean_latents = clean_latents.unsqueeze(0)

        clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None
        if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4:
            clean_latents_2x = clean_latents_2x.unsqueeze(0)

        clean_latents_4x = batch["clean_latents_4x"].to(self.device) if batch["clean_latents_4x"].numel() > 0 else None
        if clean_latents_4x is not None and len(clean_latents_4x.shape) == 4:
            clean_latents_4x = clean_latents_4x.unsqueeze(0)

        # Index tensors (empty -> None).
        latent_indices = batch["latent_indices"].to(self.device)
        clean_latent_indices = batch["clean_latent_indices"].to(self.device) if batch["clean_latent_indices"].numel() > 0 else None
        clean_latent_2x_indices = batch["clean_latent_2x_indices"].to(self.device) if batch["clean_latent_2x_indices"].numel() > 0 else None
        clean_latent_4x_indices = batch["clean_latent_4x_indices"].to(self.device) if batch["clean_latent_4x_indices"].numel() > 0 else None

        # Camera embeddings (mask channel included); random dropout for CFG.
        cam_emb = batch["camera"].to(self.device)
        camera_dropout_prob = 0.1  # 10% chance to drop the camera condition
        if random.random() < camera_dropout_prob:
            cam_emb = torch.zeros_like(cam_emb)
            print("应用camera dropout for CFG training")

        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]
        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        # Sample a training timestep and build the noisy input.
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)

        # FramePack trick: with 80% probability, also noise the condition
        # latents (with a timestep from the first 3/4 of the schedule) so the
        # model is robust to imperfect history at inference time.
        noisy_condition_latents = None
        if clean_latents is not None:
            noisy_condition_latents = copy.deepcopy(clean_latents)
            is_add_noise = random.random()
            if is_add_noise > 0.2:  # 80% of steps add condition noise
                noise_cond = torch.randn_like(clean_latents)
                timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps // 4 * 3, (1,))
                timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
                noisy_condition_latents = self.pipe.scheduler.add_noise(clean_latents, noise_cond, timestep_cond)

        extra_input = self.pipe.prepare_extra_input(latents)
        # (removed: an unused deepcopy of `latents` that wasted memory)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # FramePack-style forward; the MoE model also returns an aux loss.
        noise_pred, moe_loss = self.pipe.denoising_model()(
            noisy_latents,
            timestep=timestep,
            cam_emb=cam_emb,
            modality_inputs={"sekai": cam_emb},
            latent_indices=latent_indices,
            clean_latents=noisy_condition_latents if noisy_condition_latents is not None else clean_latents,
            clean_latent_indices=clean_latent_indices,
            clean_latents_2x=clean_latents_2x,
            clean_latent_2x_indices=clean_latent_2x_indices,
            clean_latents_4x=clean_latents_4x,
            clean_latent_4x_indices=clean_latent_4x_indices,
            **prompt_emb,
            **extra_input,
            **image_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )
        # NOTE(review): moe_loss is currently not added to the training loss —
        # confirm whether the MoE auxiliary loss should be included.

        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)
        print(f'--------loss ({dataset_name})------------:', loss)

        return loss

    def configure_optimizers(self):
        """AdamW over the (few) parameters left trainable by freeze_parameters."""
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        """Save the denoising model weights ourselves; empty Lightning's dict."""
        checkpoint_dir = "/share_zhuyixuan05/zhuyixuan05/ICLR2026/spatialvid/spatialvid_moe_test"
        os.makedirs(checkpoint_dir, exist_ok=True)

        current_step = self.global_step
        checkpoint.clear()  # keep Lightning's own checkpoint file tiny

        state_dict = self.pipe.denoising_model().state_dict()
        torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}.ckpt"))
        print(f"Saved SpatialVid FramePack model checkpoint: step{current_step}.ckpt")
def train_spatialvid_framepack(args):
    """Train the FramePack-enabled SpatialVid model end to end."""
    dataset = SpatialVidFramePackDataset(
        args.dataset_path,
        steps_per_epoch=args.steps_per_epoch,
        min_condition_frames=args.min_condition_frames,
        max_condition_frames=args.max_condition_frames,
        target_frames=args.target_frames,
    )

    # batch_size is fixed at 1: the Lightning model indexes batch[...][0].
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    model = SpatialVidFramePackLightningModel(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        resume_ckpt_path=args.resume_ckpt_path,
    )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        # save_top_k=-1 keeps every checkpoint (the model also saves its own).
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
        logger=False
    )
    trainer.fit(model, dataloader)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train SpatialVid FramePack Dynamic ReCamMaster")
    parser.add_argument("--dataset_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/spatialvid")
    parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
    parser.add_argument("--output_path", type=str, default="./")
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--steps_per_epoch", type=int, default=400)
    parser.add_argument("--max_epochs", type=int, default=30)
    parser.add_argument("--min_condition_frames", type=int, default=10, help="最小条件帧数")
    parser.add_argument("--max_condition_frames", type=int, default=40, help="最大条件帧数")
    parser.add_argument("--target_frames", type=int, default=32, help="目标帧数")
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1")
    parser.add_argument("--use_gradient_checkpointing", action="store_true")
    parser.add_argument("--use_gradient_checkpointing_offload", action="store_true")
    parser.add_argument("--resume_ckpt_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack/step1000_framepack.ckpt")

    args = parser.parse_args()

    print("🔧 开始训练SpatialVid FramePack模型:")
    print(f"📁 数据集路径: {args.dataset_path}")
    print(f"🎯 条件帧范围: {args.min_condition_frames}-{args.max_condition_frames}")
    print(f"🎯 目标帧数: {args.target_frames}")
    print("🔧 特殊优化:")
    print("   - 使用WanModelFuture模型架构")
    print("   - 添加FramePack多尺度输入支持")
    print("   - SpatialVid特有:camera间隔为1帧")
    print("   - CFG训练支持(10%概率camera dropout)")

    train_spatialvid_framepack(args)


# ===== scripts/train_load_test.py (new file starts here in the patch) =====
import copy
import os
import re
import torch, os, imageio, argparse  # NOTE(review): `os` is imported twice in this file
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
import pandas as pd
from diffsynth import WanVideoReCamMasterPipeline, ModelManager, load_state_dict
import torchvision
from PIL import Image
import numpy as np
import random
import json
import torch.nn as nn
import torch.nn.functional as F
import shutil
import wandb
import pdb
class TextVideoDataset(torch.utils.data.Dataset):
    """Pairs caption text with fixed-length video clips read via imageio.

    Rows come from a metadata CSV with `file_name` and `text` columns; media
    files live under `<base_path>/train/`.
    """

    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.is_i2v = is_i2v  # image-to-video mode also returns the first frame

        # crop_and_resize already upscales so min side >= target; the
        # CenterCrop then trims to (height, width) and Resize is a no-op pass.
        self.frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Scale a PIL image so both sides cover the target resolution."""
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Read `num_frames` frames (stride `interval`); None if the clip is too short.

        Returns a [C, T, H, W] tensor, plus the raw first frame when is_i2v.
        """
        reader = imageio.get_reader(file_path)
        if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None

        frames = []
        first_frame = None
        for frame_id in range(num_frames):
            frame = reader.get_data(start_frame_id + frame_id * interval)
            frame = Image.fromarray(frame)
            frame = self.crop_and_resize(frame)
            if first_frame is None:
                first_frame = np.array(frame)  # keep the unnormalized first frame
            frame = frame_process(frame)
            frames.append(frame)
        reader.close()

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        if self.is_i2v:
            return frames, first_frame
        else:
            return frames

    def load_video(self, file_path):
        """Load a clip starting from frame 0 with the configured stride."""
        start_frame_id = 0
        frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
        return frames

    def is_image(self, file_path):
        """True when the path's extension is a known still-image format."""
        file_ext_name = file_path.split(".")[-1]
        if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
            return True
        return False

    def load_image(self, file_path):
        """Load a still image as a single-frame [C, 1, H, W] video tensor."""
        frame = Image.open(file_path).convert("RGB")
        frame = self.crop_and_resize(frame)
        first_frame = frame
        frame = self.frame_process(frame)
        frame = rearrange(frame, "C H W -> C 1 H W")
        return frame

    def __getitem__(self, data_id):
        """Load one sample, advancing to the next index on failure.

        FIX: the original fetched `path`/`text` once *before* the retry loop
        and then incremented `data_id`, so a single corrupt video caused an
        infinite loop. Path/text are now re-read each attempt, the index wraps
        with modulo, only Exception is caught (not bare `except:`), and a
        too-short video (load_video -> None) now triggers a retry instead of
        silently returning `video=None`.
        """
        while True:
            text = self.text[data_id]
            path = self.path[data_id]
            try:
                if self.is_image(path):
                    if self.is_i2v:
                        raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
                    video = self.load_image(path)
                else:
                    video = self.load_video(path)
                if video is None:
                    raise ValueError(f"{path} has fewer frames than required.")
                if self.is_i2v:
                    video, first_frame = video
                    data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
                else:
                    data = {"text": text, "video": video, "path": path}
                break
            except Exception:
                data_id = (data_id + 1) % len(self.path)
        return data

    def __len__(self):
        return len(self.path)
class LightningModelForDataProcess(pl.LightningModule):
    """Offline encoder: caches prompt/video/image embeddings next to each video.

    Run via trainer.test(); each test_step writes `<video>.recam.pth`.
    """

    def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        model_path = [text_encoder_path, vae_path]
        if image_encoder_path is not None:
            model_path.append(image_encoder_path)
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models(model_path)
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)

        # Forwarded to encode_video; tiling trades speed for lower VRAM.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

    def test_step(self, batch, batch_idx):
        """Encode one video to latents + prompt/image embeddings and cache to disk."""
        text, video, path = batch["text"][0], batch["video"], batch["path"][0]

        self.pipe.device = self.device
        if video is not None:
            pth_path = path + ".recam.pth"
            # Skip work if this video was already encoded in a previous run.
            if not os.path.exists(pth_path):
                # prompt
                prompt_emb = self.pipe.encode_prompt(text)
                # video -> VAE latents (first element of the batch)
                video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
                latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
                # image embedding only in I2V mode (dataset adds "first_frame")
                if "first_frame" in batch:
                    first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
                    _, _, num_frames, height, width = video.shape
                    image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
                else:
                    image_emb = {}
                data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
                torch.save(data, pth_path)
                print(f"Output: {pth_path}")
            else:
                print(f"File {pth_path} already exists, skipping.")


class Camera(object):
    """Wraps a 4x4 camera-to-world matrix and caches its world-to-camera inverse."""

    def __init__(self, c2w):
        c2w_mat = np.array(c2w).reshape(4, 4)
        self.c2w_mat = c2w_mat
        self.w2c_mat = np.linalg.inv(c2w_mat)
class TensorDataset(torch.utils.data.Dataset):
    """Serves cached `.recam.pth` latents plus per-frame relative camera poses.

    Each item holds `condition_frames + target_frames` latent frames; camera
    poses are expressed relative to the last condition frame.
    """

    def __init__(self, base_path, metadata_path, steps_per_epoch, condition_frames=32, target_frames=32):
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        print(len(self.path), "videos in metadata.")
        # Keep only videos whose encoded tensor cache actually exists on disk.
        self.path = [i + ".recam.pth" for i in self.path if os.path.exists(i + ".recam.pth")]
        print(len(self.path), "tensors cached in metadata.")
        assert len(self.path) > 0
        self.steps_per_epoch = steps_per_epoch
        self.condition_frames = int(condition_frames)
        self.target_frames = int(target_frames)

    def parse_matrix(self, matrix_str):
        """Parse a "[a b ...] [c d ...]"-style string into a 2-D ndarray."""
        raw_rows = matrix_str.strip().split('] [')
        stripped = [r.replace('[', '').replace(']', '') for r in raw_rows]
        return np.array([[float(tok) for tok in r.split()] for r in stripped])

    def get_relative_pose(self, pose_prev, pose_curr):
        """Relative transform taking pose_prev to pose_curr (curr @ prev^-1)."""
        return pose_curr @ np.linalg.inv(pose_prev)

    def __getitem__(self, index):
        while True:
            try:
                sample = {}
                # Randomized but index-dependent pick over the cached tensors.
                pick = torch.randint(0, len(self.path), (1,))[0]
                pick = (pick + index) % len(self.path)
                path = self.path[pick]
                payload = torch.load(path, weights_only=True, map_location="cpu")

                latents_all = payload['latents']  # [C, T, H, W]
                total_frames = latents_all.shape[1]

                # Need room for both segments; otherwise resample.
                needed = self.condition_frames + self.target_frames
                if total_frames < needed:
                    continue

                max_start = total_frames - needed
                start_frame = random.randint(0, max_start) if max_start > 0 else 0

                # Slice out [condition | target] and concatenate in that order.
                cond_end = start_frame + self.condition_frames
                condition_latents = latents_all[:, start_frame:cond_end, :, :]
                target_latents = latents_all[:, cond_end:cond_end + self.target_frames, :, :]
                sample['latents'] = torch.cat([condition_latents, target_latents], dim=1)

                sample['prompt_emb'] = payload['prompt_emb']
                sample['image_emb'] = payload.get('image_emb', {})

                # Camera extrinsics live two directory levels above the video.
                base_path = path.rsplit('/', 2)[0]
                camera_path = os.path.join(base_path, "cameras", "camera_extrinsics.json")

                if not os.path.exists(camera_path):
                    # No camera data: zero pose vectors for the target frames.
                    pose_embedding = torch.zeros(self.target_frames, 12, dtype=torch.bfloat16)
                else:
                    with open(camera_path, 'r') as file:
                        cam_data = json.load(file)

                    # Camera index is encoded in the path as camNN (default 1).
                    match = re.search(r'cam(\d+)', path)
                    cam_idx = int(match.group(1)) if match else 1
                    cam_key = f"cam{cam_idx:02d}"

                    relative_poses = []

                    # Reference pose = last condition frame; targets are relative to it.
                    condition_end_frame_idx = start_frame + self.condition_frames - 1
                    ref_frame_key = f"frame{condition_end_frame_idx}"
                    if ref_frame_key in cam_data and cam_key in cam_data[ref_frame_key]:
                        reference_pose = self.parse_matrix(cam_data[ref_frame_key][cam_key])
                        if reference_pose.shape == (3, 4):
                            # Promote 3x4 to homogeneous 4x4.
                            reference_pose = np.vstack([reference_pose, np.array([0, 0, 0, 1.0])])
                    else:
                        reference_pose = np.eye(4, dtype=np.float32)

                    for i in range(self.target_frames):
                        frame_key = f"frame{start_frame + self.condition_frames + i}"
                        if frame_key in cam_data and cam_key in cam_data[frame_key]:
                            target_pose = self.parse_matrix(cam_data[frame_key][cam_key])
                            if target_pose.shape == (3, 4):
                                target_pose = np.vstack([target_pose, np.array([0, 0, 0, 1.0])])
                            relative_pose = self.get_relative_pose(reference_pose, target_pose)
                            relative_poses.append(torch.as_tensor(relative_pose[:3, :]))  # keep top 3 rows
                        else:
                            # Missing frame: fall back to an identity 3x4 pose.
                            relative_poses.append(torch.as_tensor(np.eye(3, 4, dtype=np.float32)))

                    pose_embedding = torch.stack(relative_poses, dim=0)  # [target_frames, 3, 4]
                    # Flatten each 3x4 pose to a 12-vector (same as rearrange 'b c d -> b (c d)').
                    pose_embedding = pose_embedding.reshape(pose_embedding.shape[0], -1)

                sample['camera'] = pose_embedding.to(torch.bfloat16)
                break

            except Exception as e:
                print(f"ERROR WHEN LOADING: {e}")
                index = random.randrange(len(self.path))

        return sample

    def __len__(self):
        return self.steps_per_epoch


def replace_dit_model_in_manager():
    """Patch diffsynth's loader registry so 'wan_video_dit' builds WanModelFuture."""
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry
        if 'wan_video_dit' not in names:
            continue
        swapped_names = []
        swapped_classes = []
        for name, cls in zip(names, classes):
            if name == 'wan_video_dit':
                swapped_names.append(name)  # registry name stays the same
                swapped_classes.append(WanModelFuture)
                print(f"✅ 替换了模型类: {name} -> WanModelFuture")
            else:
                swapped_names.append(name)
                swapped_classes.append(cls)
        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, swapped_names, swapped_classes, resource)
class LightningModelForTrain(pl.LightningModule):
    """Fine-tunes the ReCamMaster DiT: condition frames stay clean, target
    frames are denoised, and per-frame camera poses condition each block."""

    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
        resume_ckpt_path=None,
        condition_frames=10,
        target_frames=5,
    ):
        super().__init__()
        replace_dit_model_in_manager()  # must run before models are loaded
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        if os.path.isfile(dit_path):
            model_manager.load_models([dit_path])
        else:
            # NOTE(review): this passes a nested list ([list]) to load_models —
            # verify ModelManager accepts that form for sharded checkpoints.
            dit_path = dit_path.split(",")
            model_manager.load_models([dit_path])

        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)

        # Per-block camera encoder (12-dim flattened 3x4 pose -> model dim);
        # zero-init encoder + identity projector so training starts as a no-op.
        dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
        for block in self.pipe.dit.blocks:
            block.cam_encoder = nn.Linear(12, dim)
            block.projector = nn.Linear(dim, dim)
            block.cam_encoder.weight.data.zero_()
            block.cam_encoder.bias.data.zero_()
            block.projector.weight = nn.Parameter(torch.eye(dim))
            block.projector.bias = nn.Parameter(torch.zeros(dim))

        if resume_ckpt_path is not None:
            # strict=True: the checkpoint must match exactly (cam modules included).
            state_dict = torch.load(resume_ckpt_path, map_location="cpu")
            self.pipe.dit.load_state_dict(state_dict, strict=True)

        self.freeze_parameters()
        # Unfreeze camera modules and all self-attention blocks.
        for name, module in self.pipe.denoising_model().named_modules():
            if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn"]):
                print(f"Trainable: {name}")
                for param in module.parameters():
                    param.requires_grad = True
        self.condition_frames = int(condition_frames)
        self.target_frames = int(target_frames)

        # Count trainable params once, de-duplicated because named_modules()
        # yields parent and child modules (same Parameter seen repeatedly).
        trainable_params = 0
        seen_params = set()
        for name, module in self.pipe.denoising_model().named_modules():
            for param in module.parameters():
                if param.requires_grad and param not in seen_params:
                    trainable_params += param.numel()
                    seen_params.add(param)
        print(f"Total number of trainable parameters: {trainable_params}")

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

    def freeze_parameters(self):
        # Freeze everything; selected submodules are re-enabled by the caller.
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()

    def training_step(self, batch, batch_idx):
        """One denoising step: loss only on the target segment."""
        # Data: T = condition_frames + target_frames along the latent time axis.
        latents = batch["latents"].to(self.device)  # [B, C, T, H, W]
        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]

        # Center-crop spatial dims to at most 40x70 latent cells.
        # NOTE(review): 40x70 appears to be a model-resolution cap — confirm.
        target_height, target_width = 40, 70
        current_height, current_width = latents.shape[3], latents.shape[4]

        if current_height > target_height or current_width > target_width:
            h_start = (current_height - target_height) // 2
            w_start = (current_width - target_width) // 2
            latents = latents[:, :, :,
                              h_start:h_start+target_height,
                              w_start:w_start+target_width]

        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        cam_emb = batch["camera"].to(self.device)  # [B, target_frames, 12] — target-frame poses only

        # Loss
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)

        extra_input = self.pipe.prepare_extra_input(latents)
        origin_latents = copy.deepcopy(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)

        # Condition segment (front) is restored to its clean values; only the
        # target segment (back) stays noisy and is denoised.
        cond_len = self.condition_frames
        noisy_latents[:, :, :cond_len, ...] = origin_latents[:, :, :cond_len, ...]
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Forward pass (loss is later restricted to the target segment).
        noise_pred = self.pipe.denoising_model()(
            noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )

        # MSE over the target (trailing) frames only.
        target_noise_pred = noise_pred[:, :, cond_len:, ...]
        target_training_target = training_target[:, :, cond_len:, ...]

        loss = torch.nn.functional.mse_loss(
            target_noise_pred.float(),
            target_training_target.float()
        )
        loss = loss * self.pipe.scheduler.training_weight(timestep)

        # NOTE(review): wandb.log assumes wandb.init() was called (train() only
        # initializes it when --use_swanlab is set) — confirm it cannot run uninitialized.
        wandb.log({
            "train_loss": loss.item(),
            "condition_frames": cond_len,
            "target_frames": self.target_frames,
        })
        return loss

    def configure_optimizers(self):
        # AdamW over only the parameters left trainable above.
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        """Save full denoising-model weights ourselves; empty Lightning's dict."""
        checkpoint_dir = "/home/zhuyixuan05/ReCamMaster/models/checkpoints"
        print(f"Checkpoint directory: {checkpoint_dir}")
        current_step = self.global_step
        print(f"Current step: {current_step}")

        checkpoint.clear()  # keep Lightning's own checkpoint file tiny
        # NOTE(review): trainable_param_names is computed but never used (dead code).
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        state_dict = self.pipe.denoising_model().state_dict()
        torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}.ckpt"))
def parse_args():
    """Build and parse the CLI for the data_process / train tasks."""
    parser = argparse.ArgumentParser(description="Train ReCamMaster")
    # Task selection and paths.
    parser.add_argument("--task", type=str, default="train", choices=["data_process", "train"], help="Task. `data_process` or `train`.")
    parser.add_argument("--dataset_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/MultiCamVideo-Dataset", help="The path of the Dataset.")
    parser.add_argument("--output_path", type=str, default="./", help="Path to save the model.")
    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path of text encoder.")
    parser.add_argument("--image_encoder_path", type=str, default=None, help="Path of image encoder.")
    parser.add_argument("--vae_path", type=str, default=None, help="Path of VAE.")
    parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", help="Path of DiT.")
    # VAE tiling (lowers VRAM during encoding).
    parser.add_argument("--tiled", default=False, action="store_true", help="Whether enable tile encode in VAE. This option can reduce VRAM required.")
    parser.add_argument("--tile_size_height", type=int, default=34, help="Tile size (height) in VAE.")
    parser.add_argument("--tile_size_width", type=int, default=34, help="Tile size (width) in VAE.")
    parser.add_argument("--tile_stride_height", type=int, default=18, help="Tile stride (height) in VAE.")
    parser.add_argument("--tile_stride_width", type=int, default=16, help="Tile stride (width) in VAE.")
    # Data shape and loading.
    parser.add_argument("--steps_per_epoch", type=int, default=100, help="Number of steps per epoch.")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames.")
    parser.add_argument("--height", type=int, default=480, help="Image height.")
    parser.add_argument("--width", type=int, default=832, help="Image width.")
    parser.add_argument("--dataloader_num_workers", type=int, default=4, help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.")
    # Optimization.
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="The number of batches in gradient accumulation.")
    parser.add_argument("--max_epochs", type=int, default=2, help="Number of epochs.")
    parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1", choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"], help="Training strategy")
    parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
    parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to use gradient checkpointing offload.")
    # Logging.
    # NOTE(review): store_true with default=True means this flag can never be
    # turned off from the CLI; left unchanged to preserve behavior.
    parser.add_argument("--use_swanlab", default=True, action="store_true", help="Whether to use SwanLab logger.")
    parser.add_argument("--swanlab_mode", default="cloud", help="SwanLab mode (cloud or local).")
    # Misc.
    parser.add_argument("--metadata_file_name", type=str, default="metadata.csv")
    parser.add_argument("--resume_ckpt_path", type=str, default=None)
    parser.add_argument("--condition_frames", type=int, default=8, help="Number of condition frames (kept clean).")
    parser.add_argument("--target_frames", type=int, default=8, help="Number of target frames (to be denoised).")
    return parser.parse_args()


def data_process(args):
    """Run the offline encoding pass that caches `.recam.pth` tensors."""
    dataset = TextVideoDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, args.metadata_file_name),
        max_num_frames=args.num_frames,
        frame_interval=1,
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        is_i2v=args.image_encoder_path is not None,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        batch_size=1,
        num_workers=args.dataloader_num_workers,
    )
    encoder_model = LightningModelForDataProcess(
        text_encoder_path=args.text_encoder_path,
        image_encoder_path=args.image_encoder_path,
        vae_path=args.vae_path,
        tiled=args.tiled,
        tile_size=(args.tile_size_height, args.tile_size_width),
        tile_stride=(args.tile_stride_height, args.tile_stride_width),
    )
    runner = pl.Trainer(
        accelerator="gpu",
        devices="auto",
        default_root_dir=args.output_path,
    )
    runner.test(encoder_model, loader)


def train(args):
    """Fine-tune the DiT on cached tensors with camera conditioning."""
    dataset = TensorDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        steps_per_epoch=args.steps_per_epoch,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,  # Lightning model indexes batch[...][0]
        num_workers=args.dataloader_num_workers,
    )
    train_model = LightningModelForTrain(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        resume_ckpt_path=args.resume_ckpt_path,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
    )

    if args.use_swanlab:
        wandb.init(
            project="recam",
            name="recam",
        )

    runner = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
    )
    runner.fit(train_model, loader)


if __name__ == '__main__':
    args = parse_args()
    os.makedirs(os.path.join(args.output_path, "checkpoints"), exist_ok=True)
    if args.task == "data_process":
        data_process(args)
    elif args.task == "train":
        train(args)
file mode 100644 index 0000000000000000000000000000000000000000..3af69fdea2bdffe449e9bb0fdf3406604d0aec37 --- /dev/null +++ b/scripts/train_moe.py @@ -0,0 +1,1057 @@ +#融合nuscenes和sekai数据集的MoE训练 +import torch +import torch.nn as nn +import lightning as pl +import wandb +import os +import time +import copy +import json +import numpy as np +import random +import traceback +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +from torchvision.transforms import v2 +from einops import rearrange +from pose_classifier import PoseClassifier +import argparse +from scipy.spatial.transform import Rotation as R + +def get_traj_position_change(cam_c2w, stride=1): + positions = cam_c2w[:, :3, 3] + + traj_coord = [] + tarj_angle = [] + for i in range(0, len(positions) - 2 * stride): + v1 = positions[i + stride] - positions[i] + v2 = positions[i + 2 * stride] - positions[i + stride] + + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + if norm1 < 1e-6 or norm2 < 1e-6: + continue + + cos_angle = np.dot(v1, v2) / (norm1 * norm2) + angle = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0))) + + traj_coord.append(v1) + tarj_angle.append(angle) + + return traj_coord, tarj_angle + +def get_traj_rotation_change(cam_c2w, stride=1): + rotations = cam_c2w[:, :3, :3] + + traj_rot_angle = [] + for i in range(0, len(rotations) - stride): + z1 = rotations[i][:, 2] + z2 = rotations[i + stride][:, 2] + + norm1 = np.linalg.norm(z1) + norm2 = np.linalg.norm(z2) + if norm1 < 1e-6 or norm2 < 1e-6: + continue + + cos_angle = np.dot(z1, z2) / (norm1 * norm2) + angle = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0))) + traj_rot_angle.append(angle) + + return traj_rot_angle + +def compute_relative_pose(pose_a, pose_b, use_torch=False): + """计算相机B相对于相机A的相对位姿矩阵""" + assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}" + assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}" + + if use_torch: + if not isinstance(pose_a, torch.Tensor): + pose_a = 
torch.from_numpy(pose_a).float() + if not isinstance(pose_b, torch.Tensor): + pose_b = torch.from_numpy(pose_b).float() + + pose_a_inv = torch.inverse(pose_a) + relative_pose = torch.matmul(pose_b, pose_a_inv) + else: + if not isinstance(pose_a, np.ndarray): + pose_a = np.array(pose_a, dtype=np.float32) + if not isinstance(pose_b, np.ndarray): + pose_b = np.array(pose_b, dtype=np.float32) + + pose_a_inv = np.linalg.inv(pose_a) + relative_pose = np.matmul(pose_b, pose_a_inv) + + return relative_pose + +def compute_relative_pose_matrix(pose1, pose2): + """ + 计算相邻两帧的相对位姿,返回3×4的相机矩阵 [R_rel | t_rel] + + 参数: + pose1: 第i帧的相机位姿,形状为(7,)的数组 [tx1, ty1, tz1, qx1, qy1, qz1, qw1] + pose2: 第i+1帧的相机位姿,形状为(7,)的数组 [tx2, ty2, tz2, qx2, qy2, qz2, qw2] + + 返回: + relative_matrix: 3×4的相对位姿矩阵,前3列是旋转矩阵R_rel,第4列是平移向量t_rel + """ + # 分离平移向量和四元数 + t1 = pose1[:3] # 第i帧平移 [tx1, ty1, tz1] + q1 = pose1[3:] # 第i帧四元数 [qx1, qy1, qz1, qw1] + t2 = pose2[:3] # 第i+1帧平移 + q2 = pose2[3:] # 第i+1帧四元数 + + # 1. 计算相对旋转矩阵 R_rel + rot1 = R.from_quat(q1) # 第i帧旋转 + rot2 = R.from_quat(q2) # 第i+1帧旋转 + rot_rel = rot2 * rot1.inv() # 相对旋转 = 后一帧旋转 × 前一帧旋转的逆 + R_rel = rot_rel.as_matrix() # 转换为3×3矩阵 + + # 2. 计算相对平移向量 t_rel + R1_T = rot1.as_matrix().T # 前一帧旋转矩阵的转置(等价于逆) + t_rel = R1_T @ (t2 - t1) # 相对平移 = R1^T × (t2 - t1) + + # 3. 
class MultiDatasetDynamicDataset(torch.utils.data.Dataset):
    """Multi-dataset dynamic-history-length dataset with FramePack support — fuses nuscenes and sekai."""

    def __init__(self, dataset_configs, steps_per_epoch,
                 min_condition_frames=10, max_condition_frames=40,
                 target_frames=10, height=900, width=1600):
        """
        Args:
            dataset_configs: list of dataset configs, each of the form {
                'name': dataset name,
                'paths': list of dataset paths,
                'type': dataset type ('sekai' or 'nuscenes'),
                'weight': sampling weight
            }
        """
        self.dataset_configs = dataset_configs
        self.min_condition_frames = min_condition_frames
        self.max_condition_frames = max_condition_frames
        self.target_frames = target_frames
        self.height = height
        self.width = width
        self.steps_per_epoch = steps_per_epoch
        self.pose_classifier = PoseClassifier()

        # VAE temporal compression ratio (latent frame == 4 video frames).
        self.time_compression_ratio = 4

        # Scan every configured dataset and build one flat scene index.
        self.scene_dirs = []
        self.dataset_info = {}      # per-scene dataset metadata
        self.dataset_weights = []   # per-scene sampling weight

        total_scenes = 0

        for cfg in self.dataset_configs:
            dataset_name = cfg['name']
            dataset_paths = cfg['paths'] if isinstance(cfg['paths'], list) else [cfg['paths']]
            dataset_type = cfg['type']
            dataset_weight = cfg.get('weight', 1.0)

            print(f"🔧 扫描数据集: {dataset_name} (类型: {dataset_type})")

            dataset_scenes = []
            for dataset_path in dataset_paths:
                print(f" 📁 检查路径: {dataset_path}")
                if not os.path.exists(dataset_path):
                    print(f" ❌ 路径不存在: {dataset_path}")
                    continue

                if dataset_type == 'nuscenes':
                    # NuScenes layout: base_path/scenes/<scene_dir>.
                    scenes_path = os.path.join(dataset_path, "scenes")
                    print(f" 📂 扫描NuScenes scenes目录: {scenes_path}")
                    for entry in os.listdir(scenes_path):
                        scene_dir = os.path.join(scenes_path, entry)
                        if not os.path.isdir(scene_dir):
                            continue
                        self.scene_dirs.append(scene_dir)
                        dataset_scenes.append(scene_dir)
                        self.dataset_info[scene_dir] = {
                            'name': dataset_name,
                            'type': dataset_type,
                            'weight': dataset_weight
                        }
                        self.dataset_weights.append(dataset_weight)
                elif dataset_type in ['sekai', 'spatialvid', 'openx']:
                    # These datasets keep scene dirs directly under the root;
                    # only dirs that already contain an encoded video count.
                    for entry in os.listdir(dataset_path):
                        scene_dir = os.path.join(dataset_path, entry)
                        if not os.path.isdir(scene_dir):
                            continue
                        encoded_path = os.path.join(scene_dir, "encoded_video.pth")
                        if not os.path.exists(encoded_path):
                            continue
                        self.scene_dirs.append(scene_dir)
                        dataset_scenes.append(scene_dir)
                        self.dataset_info[scene_dir] = {
                            'name': dataset_name,
                            'type': dataset_type,
                            'weight': dataset_weight
                        }
                        self.dataset_weights.append(dataset_weight)

            print(f" ✅ 找到 {len(dataset_scenes)} 个场景")
            total_scenes += len(dataset_scenes)

        # Per-dataset scene counts for the log.
        dataset_counts = {}
        for scene_dir in self.scene_dirs:
            info = self.dataset_info[scene_dir]
            key = f"{info['name']} ({info['type']})"
            dataset_counts[key] = dataset_counts.get(key, 0) + 1

        for dataset_key, count in dataset_counts.items():
            print(f" - {dataset_key}: {count} 个场景")

        assert len(self.scene_dirs) > 0, "No encoded scenes found!"

        # Normalize weights into sampling probabilities.
        total_weight = sum(self.dataset_weights)
        self.sampling_probs = [w / total_weight for w in self.dataset_weights]

    def calculate_relative_rotation(self, current_rotation, reference_rotation):
        """Relative rotation quaternion (NuScenes-specific): q_ref^-1 * q_current.

        Quaternions are in [w, x, y, z] order; the conjugate of the unit
        reference quaternion serves as its inverse.
        """
        q_current = torch.tensor(current_rotation, dtype=torch.float32)
        q_ref = torch.tensor(reference_rotation, dtype=torch.float32)

        q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]])

        w1, x1, y1, z1 = q_ref_inv
        w2, x2, y2, z2 = q_current

        # Hamilton product q_ref_inv * q_current.
        relative_rotation = torch.tensor([
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
        ])

        return relative_rotation

    def prepare_framepack_inputs(self, full_latents, segment_info):
        """Build FramePack-style multi-scale inputs; handles empty index sets."""
        # Accept 4-D [C, T, H, W] input by adding a batch dimension.
        if len(full_latents.shape) == 4:
            full_latents = full_latents.unsqueeze(0)  # -> [1, C, T, H, W]
            B, C, T, H, W = full_latents.shape
        else:
            B, C, T, H, W = full_latents.shape

        # Main latents: the frames the model will denoise.
        latent_indices = segment_info['latent_indices']
        main_latents = full_latents[:, :, latent_indices, :, :]

        # 1x condition frames (start frame + last condition frame).
        clean_latent_indices = segment_info['clean_latent_indices']
        clean_latents = full_latents[:, :, clean_latent_indices, :, :]

        # 4x condition frames — always 16 slots: real indices + zero padding.
        clean_latent_4x_indices = segment_info['clean_latent_4x_indices']

        clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype)
        clean_latent_4x_indices_final = torch.full((16,), -1, dtype=torch.long)  # -1 marks padding

        # Only fill when there are valid 4x indices.
        if len(clean_latent_4x_indices) > 0:
            actual_4x_frames = len(clean_latent_4x_indices)
            # Right-align so the most recent frames occupy the last slots.
            start_pos = max(0, 16 - actual_4x_frames)
            end_pos = 16
            actual_start = max(0, actual_4x_frames - 16)  # keep only the last 16 if longer

            clean_latents_4x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :]
full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :] + clean_latent_4x_indices_final[start_pos:end_pos] = clean_latent_4x_indices[actual_start:] + + # 🔧 2x条件帧 - 总是2帧,直接用真实索引 + 0填充 + clean_latent_2x_indices = segment_info['clean_latent_2x_indices'] + + # 创建固定长度2的latents,初始化为0 + clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype) + clean_latent_2x_indices_final = torch.full((2,), -1, dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的2x索引 + if len(clean_latent_2x_indices) > 0: + actual_2x_frames = len(clean_latent_2x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 2 - actual_2x_frames) + end_pos = 2 + actual_start = max(0, actual_2x_frames - 2) # 如果超过2帧,只取最后2帧 + + clean_latents_2x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_2x_indices[actual_start:], :, :] + clean_latent_2x_indices_final[start_pos:end_pos] = clean_latent_2x_indices[actual_start:] + + # 🔧 移除添加的batch维度,返回原始格式 + if B == 1: + main_latents = main_latents.squeeze(0) # [1, C, T, H, W] -> [C, T, H, W] + clean_latents = clean_latents.squeeze(0) + clean_latents_2x = clean_latents_2x.squeeze(0) + clean_latents_4x = clean_latents_4x.squeeze(0) + + return { + 'latents': main_latents, + 'clean_latents': clean_latents, + 'clean_latents_2x': clean_latents_2x, + 'clean_latents_4x': clean_latents_4x, + 'latent_indices': segment_info['latent_indices'], + 'clean_latent_indices': segment_info['clean_latent_indices'], + 'clean_latent_2x_indices': clean_latent_2x_indices_final, # 🔧 使用真实索引(含-1填充) + 'clean_latent_4x_indices': clean_latent_4x_indices_final, # 🔧 使用真实索引(含-1填充) + } + + def create_sekai_pose_embeddings(self, cam_data, segment_info): + """创建Sekai风格的pose embeddings""" + cam_data_seq = cam_data['extrinsic'] + + # 为所有帧计算相对pose + all_keyframe_indices = [] + for compressed_idx in range(segment_info['start_frame'], segment_info['target_range'][1]): + all_keyframe_indices.append(compressed_idx * 4) + + relative_cams = [] + for idx in all_keyframe_indices: + cam_prev 
= cam_data_seq[idx] + cam_next = cam_data_seq[idx + 4] + relative_cam = compute_relative_pose(cam_prev, cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def create_openx_pose_embeddings(self, cam_data, segment_info): + """🔧 创建OpenX风格的pose embeddings - 类似sekai但处理更短的序列""" + cam_data_seq = cam_data['extrinsic'] + + # 为所有帧计算相对pose - OpenX使用4倍间隔 + all_keyframe_indices = [] + for compressed_idx in range(segment_info['start_frame'], segment_info['target_range'][1]): + keyframe_idx = compressed_idx * 4 + if keyframe_idx + 4 < len(cam_data_seq): + all_keyframe_indices.append(keyframe_idx) + + relative_cams = [] + for idx in all_keyframe_indices: + if idx + 4 < len(cam_data_seq): + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx + 4] + relative_cam = compute_relative_pose(cam_prev, cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + else: + # 如果没有下一帧,使用单位矩阵 + identity_cam = torch.eye(3, 4) + relative_cams.append(identity_cam) + + if len(relative_cams) == 0: + return None + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def create_spatialvid_pose_embeddings(self, cam_data, segment_info): + """🔧 创建SpatialVid风格的pose embeddings - camera间隔为1帧而非4帧""" + cam_data_seq = cam_data['extrinsic'] # N * 4 * 4 + + # 🔧 为所有帧(condition + target)计算camera embedding + # SpatialVid特有:每隔1帧而不是4帧 + keyframe_original_idx = segment_info['keyframe_original_idx'] + + relative_cams = [] + for idx in keyframe_original_idx: + if idx + 1 < len(cam_data_seq): + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx + 1] # SpatialVid: 每隔1帧 + relative_cam = compute_relative_pose_matrix(cam_prev, 
cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + else: + # 如果没有下一帧,使用零运动 + identity_cam = torch.zeros(3, 4) + relative_cams.append(identity_cam) + + if len(relative_cams) == 0: + return None + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def create_nuscenes_pose_embeddings_framepack(self, scene_info, segment_info): + """创建NuScenes风格的pose embeddings - 每帧都相对上一帧(7维)""" + keyframe_poses = scene_info['keyframe_poses'] + # 生成所有需要的关键帧索引 + start_frame = segment_info['start_frame'] + total_frames = segment_info['condition_frames'] + segment_info['target_frames'] + keyframe_indices = [] + for i in range(total_frames + 1): # +1是因为需要前后两帧 + idx = (start_frame + i) * self.time_compression_ratio + keyframe_indices.append(idx) + # 边界检查 + keyframe_indices = [min(idx, len(keyframe_poses)-1) for idx in keyframe_indices] + + pose_vecs = [] + for i in range(total_frames): + pose_prev = keyframe_poses[keyframe_indices[i]] + pose_next = keyframe_poses[keyframe_indices[i+1]] + # 计算相对位姿 + translation = torch.tensor( + np.array(pose_next['translation']) - np.array(pose_prev['translation']), + dtype=torch.float32 + ) + relative_rotation = self.calculate_relative_rotation( + pose_next['rotation'], + pose_prev['rotation'] + ) + pose_vec = torch.cat([translation, relative_rotation], dim=0) # [7D] + pose_vecs.append(pose_vec) + + if not pose_vecs: + return None + pose_sequence = torch.stack(pose_vecs, dim=0) # [total_frames, 7] + return pose_sequence + + # 修改select_dynamic_segment方法 + def select_dynamic_segment(self, full_latents, dataset_type, scene_info=None): + """🔧 根据数据集类型选择不同的段落选择策略""" + # 原有的sekai方式 + total_lens = full_latents.shape[1] + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + 
target_frames_compressed = self.target_frames // self.time_compression_ratio + max_condition_compressed = min(total_lens-target_frames_compressed-1, max_condition_compressed) + + ratio = random.random() + if ratio < 0.15: + condition_frames_compressed = 1 + elif 0.15 <= ratio < 0.9 or total_lens <= 2*target_frames_compressed + 1: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if total_lens < min_required_frames: + return None + + start_frame_compressed = random.randint(0, total_lens - min_required_frames - 1) + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + # FramePack风格的索引处理 + latent_indices = torch.arange(condition_end_compressed, target_end_compressed) + + # 1x帧:起始帧 + 最后1帧 + clean_latent_indices_start = torch.tensor([start_frame_compressed]) + clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1]) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices]) + + # 🔧 2x帧:根据实际condition长度确定 + if condition_frames_compressed >= 2: + # 取最后2帧(如果有的话) + clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2-1) + clean_latent_2x_indices = torch.arange(clean_latent_2x_start, condition_end_compressed-1) + else: + # 如果condition帧数不足2帧,创建空索引 + clean_latent_2x_indices = torch.tensor([], dtype=torch.long) + + # 🔧 4x帧:根据实际condition长度确定,最多16帧 + if condition_frames_compressed > 3: + # 取最多16帧的历史(如果有的话) + clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16-3) + clean_latent_4x_indices = torch.arange(clean_4x_start, condition_end_compressed-3) + else: + clean_latent_4x_indices = torch.tensor([], dtype=torch.long) + + # 对应的原始关键帧索引 + keyframe_original_idx = [] + for 
compressed_idx in range(start_frame_compressed, target_end_compressed): + if dataset_type == 'spatialvid': + keyframe_original_idx.append(compressed_idx) # spatialvid直接使用compressed_idx + elif dataset_type == 'openx' or 'sekai' or "nuscenes": # 🔧 新增openx处理 + keyframe_original_idx.append(compressed_idx * 4) # openx使用4倍间隔 + + return { + 'start_frame': start_frame_compressed, + 'condition_frames': condition_frames_compressed, + 'target_frames': target_frames_compressed, + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + + # FramePack风格的索引 + 'latent_indices': latent_indices, + 'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + + 'keyframe_original_idx': keyframe_original_idx, + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + + # 修改create_pose_embeddings方法 + def create_pose_embeddings(self, cam_data, segment_info, dataset_type, scene_info=None): + """🔧 根据数据集类型创建pose embeddings""" + if dataset_type == 'nuscenes' and scene_info is not None: + return self.create_nuscenes_pose_embeddings_framepack(scene_info, segment_info) + elif dataset_type == 'spatialvid': # 🔧 新增spatialvid处理 + return self.create_spatialvid_pose_embeddings(cam_data, segment_info) + elif dataset_type == 'sekai': + return self.create_sekai_pose_embeddings(cam_data, segment_info) + elif dataset_type == 'openx': # 🔧 新增openx处理 + return self.create_openx_pose_embeddings(cam_data, segment_info) + + def __getitem__(self, index): + while True: + try: + # 根据权重随机选择场景 + scene_idx = np.random.choice(len(self.scene_dirs), p=self.sampling_probs) + scene_dir = self.scene_dirs[scene_idx] + dataset_info = self.dataset_info[scene_dir] + + dataset_name = dataset_info['name'] + dataset_type = 
dataset_info['type'] + + # 🔧 根据数据集类型加载数据 + scene_info = None + if dataset_type == 'nuscenes': + # NuScenes需要加载scene_info.json + scene_info_path = os.path.join(scene_dir, "scene_info.json") + if os.path.exists(scene_info_path): + with open(scene_info_path, 'r') as f: + scene_info = json.load(f) + + # NuScenes使用不同的编码文件名 + encoded_path = os.path.join(scene_dir, "encoded_video-480p.pth") + if not os.path.exists(encoded_path): + encoded_path = os.path.join(scene_dir, "encoded_video.pth") # fallback + + encoded_data = torch.load(encoded_path, weights_only=True, map_location="cpu") + else: + # Sekai数据集 + encoded_path = os.path.join(scene_dir, "encoded_video.pth") + encoded_data = torch.load(encoded_path, weights_only=False, map_location="cpu") + + full_latents = encoded_data['latents'] + if full_latents.shape[1] <= 10: + continue + cam_data = encoded_data.get('cam_emb', encoded_data) + + # 🔧 验证NuScenes的latent帧数 + if dataset_type == 'nuscenes' and scene_info is not None: + expected_latent_frames = scene_info['total_frames'] // self.time_compression_ratio + actual_latent_frames = full_latents.shape[1] + + if abs(actual_latent_frames - expected_latent_frames) > 2: + print(f"⚠️ NuScenes Latent帧数不匹配,跳过此样本") + continue + + # 使用数据集特定的段落选择策略 + segment_info = self.select_dynamic_segment(full_latents, dataset_type, scene_info) + if segment_info is None: + continue + + # 创建数据集特定的pose embeddings + all_camera_embeddings = self.create_pose_embeddings(cam_data, segment_info, dataset_type, scene_info) + if all_camera_embeddings is None: + continue + + # 准备FramePack风格的多尺度输入 + framepack_inputs = self.prepare_framepack_inputs(full_latents, segment_info) + + n = segment_info["condition_frames"] + m = segment_info['target_frames'] + + # 处理camera embedding with mask + mask = torch.zeros(n+m, dtype=torch.float32) + mask[:n] = 1.0 + mask = mask.view(-1, 1) + + # 🔧 NuScenes返回的是直接的embedding,Sekai返回的是tensor + if isinstance(all_camera_embeddings, torch.Tensor): + camera_with_mask = 
torch.cat([all_camera_embeddings, mask], dim=1) + else: + # NuScenes风格,直接就是最终的embedding + camera_with_mask = torch.cat([all_camera_embeddings, mask], dim=1) + + result = { + # FramePack风格的多尺度输入 + "latents": framepack_inputs['latents'], + "clean_latents": framepack_inputs['clean_latents'], + "clean_latents_2x": framepack_inputs['clean_latents_2x'], + "clean_latents_4x": framepack_inputs['clean_latents_4x'], + "latent_indices": framepack_inputs['latent_indices'], + "clean_latent_indices": framepack_inputs['clean_latent_indices'], + "clean_latent_2x_indices": framepack_inputs['clean_latent_2x_indices'], + "clean_latent_4x_indices": framepack_inputs['clean_latent_4x_indices'], + + # Camera数据 + "camera": camera_with_mask, + + # 其他数据 + "prompt_emb": encoded_data["prompt_emb"], + "image_emb": encoded_data.get("image_emb", {}), + + # 元信息 + "condition_frames": n, + "target_frames": m, + "scene_name": os.path.basename(scene_dir), + "dataset_name": dataset_name, + "dataset_type": dataset_type, + "original_condition_frames": segment_info['original_condition_frames'], + "original_target_frames": segment_info['original_target_frames'], + } + + return result + + except Exception as e: + print(f"Error loading sample: {e}") + traceback.print_exc() + continue + + def __len__(self): + return self.steps_per_epoch + +def replace_dit_model_in_manager(): + """在模型加载前替换DiT模型类为MoE版本""" + from diffsynth.models.wan_video_dit_moe import WanModelMoe + from diffsynth.configs.model_config import model_loader_configs + + # 修改model_loader_configs中的配置 + for i, config in enumerate(model_loader_configs): + keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config + + # 检查是否包含wan_video_dit模型 + if 'wan_video_dit' in model_names: + new_model_names = [] + new_model_classes = [] + + for name, cls in zip(model_names, model_classes): + if name == 'wan_video_dit': + new_model_names.append(name) + new_model_classes.append(WanModelMoe) # 🔧 使用MoE版本 + print(f"✅ 替换了模型类: {name} -> 
WanModelMoe") + else: + new_model_names.append(name) + new_model_classes.append(cls) + + # 更新配置 + model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) + +class MultiDatasetLightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + resume_ckpt_path=None, + # 🔧 MoE参数 + use_moe=False, + moe_config=None + ): + super().__init__() + self.use_moe = use_moe + self.moe_config = moe_config or {} + + replace_dit_model_in_manager() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + # 添加FramePack的clean_x_embedder + self.add_framepack_components() + if self.use_moe: + self.add_moe_components() + + # 🔧 添加camera编码器(wan_video_dit_moe.py已经包含MoE逻辑) + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + # 🔧 简化:只添加传统camera编码器,MoE逻辑在wan_video_dit_moe.py中 + block.cam_encoder = nn.Linear(13, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + state_dict.pop("global_router.weight", None) + state_dict.pop("global_router.bias", None) + self.pipe.dit.load_state_dict(state_dict, strict=False) + print('load checkpoint:', resume_ckpt_path) + + self.freeze_parameters() + + # 🔧 训练参数设置 + for name, module in 
self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn", "clean_x_embedder", + "moe", "sekai_processor", "nuscenes_processor","openx_processor"]): + for param in module.parameters(): + param.requires_grad = True + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + # 创建可视化目录 + self.vis_dir = "multi_dataset_dynamic/visualizations" + os.makedirs(self.vis_dir, exist_ok=True) + + def add_moe_components(self): + """🔧 添加MoE相关组件 - 简化版,只为每个block添加MoE,全局processor在WanModelMoe中""" + if not hasattr(self.pipe.dit, 'moe_config'): + self.pipe.dit.moe_config = self.moe_config + print("✅ 添加了MoE配置到模型") + self.pipe.dit.top_k = self.moe_config.get("top_k", 1) + + # 为每个block添加MoE组件(modality processors已经在WanModelMoe中全局创建) + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + unified_dim = self.moe_config.get("unified_dim", 30) + num_experts = self.moe_config.get("num_experts", 4) + from diffsynth.models.wan_video_dit_moe import MultiModalMoE, ModalityProcessor + + self.pipe.dit.sekai_processor = ModalityProcessor("sekai", 13, unified_dim) + self.pipe.dit.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim) + self.pipe.dit.openx_processor = ModalityProcessor("openx", 13, unified_dim) # OpenX使用13维输入,类似sekai但独立处理 + self.pipe.dit.global_router = nn.Linear(unified_dim, num_experts) + + for i, block in enumerate(self.pipe.dit.blocks): + # 只为每个block添加MoE网络 + block.moe = MultiModalMoE( + unified_dim=unified_dim, + output_dim=dim, + num_experts=self.moe_config.get("num_experts", 4), + top_k=self.moe_config.get("top_k", 2) + ) + + print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {self.moe_config.get('num_experts', 4)})") + + + def add_framepack_components(self): + """🔧 添加FramePack相关组件""" + if not hasattr(self.pipe.dit, 'clean_x_embedder'): + inner_dim = 
self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + if scale == "1x": + return self.proj(x) + elif scale == "2x": + return self.proj_2x(x) + elif scale == "4x": + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + self.pipe.dit.clean_x_embedder = CleanXEmbedder(inner_dim) + print("✅ 添加了FramePack的clean_x_embedder组件") + + def freeze_parameters(self): + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + def training_step(self, batch, batch_idx): + """🔧 多数据集训练步骤""" + condition_frames = batch["condition_frames"][0].item() + target_frames = batch["target_frames"][0].item() + + original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0] + original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0] + + dataset_name = batch.get("dataset_name", ["unknown"])[0] + dataset_type = batch.get("dataset_type", ["sekai"])[0] + scene_name = batch.get("scene_name", ["unknown"])[0] + + # 准备输入数据 + latents = batch["latents"].to(self.device) + if len(latents.shape) == 4: + latents = latents.unsqueeze(0) + + clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None + if clean_latents is not None and len(clean_latents.shape) == 4: + clean_latents = clean_latents.unsqueeze(0) + + clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None + if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4: + clean_latents_2x = clean_latents_2x.unsqueeze(0) + + clean_latents_4x = 
def training_step(self, batch, batch_idx):
    """One multi-dataset training step: diffusion reconstruction loss plus a
    weighted MoE specialization loss.

    NOTE: batch_size is 1 upstream, hence the pervasive `[0]` indexing into
    collated batch fields.
    """
    # Segment sizes in compressed (latent) frames; *4 converts back to pixel frames.
    condition_frames = batch["condition_frames"][0].item()
    target_frames = batch["target_frames"][0].item()

    # NOTE(review): these four locals are computed but never used below — kept as-is.
    original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0]
    original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0]

    dataset_name = batch.get("dataset_name", ["unknown"])[0]
    dataset_type = batch.get("dataset_type", ["sekai"])[0]
    scene_name = batch.get("scene_name", ["unknown"])[0]

    # Prepare input latents; a 4-D tensor means the batch dim was squeezed away.
    latents = batch["latents"].to(self.device)
    if len(latents.shape) == 4:
        latents = latents.unsqueeze(0)

    # Empty tensors (numel == 0) signal "no clean context at this scale".
    clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None
    if clean_latents is not None and len(clean_latents.shape) == 4:
        clean_latents = clean_latents.unsqueeze(0)

    clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None
    if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4:
        clean_latents_2x = clean_latents_2x.unsqueeze(0)

    clean_latents_4x = batch["clean_latents_4x"].to(self.device) if batch["clean_latents_4x"].numel() > 0 else None
    if clean_latents_4x is not None and len(clean_latents_4x.shape) == 4:
        clean_latents_4x = clean_latents_4x.unsqueeze(0)

    # FramePack frame-position indices (may be empty at the 1x/2x/4x scales).
    latent_indices = batch["latent_indices"].to(self.device)
    clean_latent_indices = batch["clean_latent_indices"].to(self.device) if batch["clean_latent_indices"].numel() > 0 else None
    clean_latent_2x_indices = batch["clean_latent_2x_indices"].to(self.device) if batch["clean_latent_2x_indices"].numel() > 0 else None
    clean_latent_4x_indices = batch["clean_latent_4x_indices"].to(self.device) if batch["clean_latent_4x_indices"].numel() > 0 else None

    # Camera embedding.
    cam_emb = batch["camera"].to(self.device)

    # Route the camera embedding to the matching modality processor.
    # spatialvid deliberately reuses the "sekai" processor (same 13-D layout).
    if dataset_type == "sekai":
        modality_inputs = {"sekai": cam_emb}
    elif dataset_type == "spatialvid":
        modality_inputs = {"sekai": cam_emb}  # note: keyed as "sekai" on purpose
    elif dataset_type == "nuscenes":
        modality_inputs = {"nuscenes": cam_emb}
    elif dataset_type == "openx":
        modality_inputs = {"openx": cam_emb}  # openx has its own processor
    else:
        modality_inputs = {"sekai": cam_emb}  # default fallback

    # Classifier-free-guidance style dropout: zero the camera conditioning 5% of the time.
    camera_dropout_prob = 0.05
    if random.random() < camera_dropout_prob:
        cam_emb = torch.zeros_like(cam_emb)
        # Zero the modality inputs as well, so the MoE sees the same dropout.
        for key in modality_inputs:
            modality_inputs[key] = torch.zeros_like(modality_inputs[key])
        print(f"应用camera dropout for CFG training (dataset: {dataset_name}, type: {dataset_type})")

    prompt_emb = batch["prompt_emb"]
    prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
    image_emb = batch["image_emb"]

    if "clip_feature" in image_emb:
        image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
    if "y" in image_emb:
        image_emb["y"] = image_emb["y"][0].to(self.device)

    # Diffusion loss setup: sample a random timestep and noise the target latents.
    self.pipe.device = self.device
    noise = torch.randn_like(latents)
    timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
    timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)

    # FramePack-style condition corruption: with 80% probability, noise the clean
    # condition latents too (timestep drawn from the first 3/4 of the schedule).
    noisy_condition_latents = None
    if clean_latents is not None:
        noisy_condition_latents = copy.deepcopy(clean_latents)
        is_add_noise = random.random()
        if is_add_noise > 0.2:
            noise_cond = torch.randn_like(clean_latents)
            timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps//4*3, (1,))
            timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
            noisy_condition_latents = self.pipe.scheduler.add_noise(clean_latents, noise_cond, timestep_cond)

    extra_input = self.pipe.prepare_extra_input(latents)
    # NOTE(review): origin_latents is never used after this deepcopy — kept as-is.
    origin_latents = copy.deepcopy(latents)
    noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)

    training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

    # Forward pass; the MoE-enabled DiT returns (prediction, specialization loss).
    noise_pred, specialization_loss = self.pipe.denoising_model()(
        noisy_latents,
        timestep=timestep,
        cam_emb=cam_emb,
        modality_inputs=modality_inputs,  # multi-modal routing inputs
        latent_indices=latent_indices,
        clean_latents=noisy_condition_latents if noisy_condition_latents is not None else clean_latents,
        clean_latent_indices=clean_latent_indices,
        clean_latents_2x=clean_latents_2x,
        clean_latent_2x_indices=clean_latent_2x_indices,
        clean_latents_4x=clean_latents_4x,
        clean_latent_4x_indices=clean_latent_4x_indices,
        **prompt_emb,
        **extra_input,
        **image_emb,
        use_gradient_checkpointing=self.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
    )

    # Total loss = timestep-weighted reconstruction loss + weighted MoE specialization loss.
    reconstruction_loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    reconstruction_loss = reconstruction_loss * self.pipe.scheduler.training_weight(timestep)

    specialization_loss_weight = self.moe_config.get("moe_loss_weight", 0.1)
    total_loss = reconstruction_loss + specialization_loss_weight * specialization_loss

    print(f'\n loss info (step {self.global_step}):')
    print(f'   - diff loss: {reconstruction_loss.item():.6f}')
    print(f'   - MoE specification loss: {specialization_loss.item():.6f}')
    print(f'   - Expert loss weight: {specialization_loss_weight}')
    print(f'   - Total Loss: {total_loss.item():.6f}')

    # Log which expert this modality is expected to specialize toward.
    modality_to_expert = {
        "sekai": 0,
        "nuscenes": 1,
        "openx": 2
    }
    expected_expert = modality_to_expert.get(dataset_type, 0)
    print(f'   - current modality: {dataset_type} -> expected expert: {expected_expert}')

    return total_loss
def configure_optimizers(self):
    """Build an AdamW optimizer over the trainable (unfrozen) parameters only."""
    trainable = (p for p in self.pipe.denoising_model().parameters() if p.requires_grad)
    return torch.optim.AdamW(trainable, lr=self.learning_rate)

def on_save_checkpoint(self, checkpoint):
    """Save the denoising model's weights to a side directory, emptying
    Lightning's own checkpoint payload to keep its file small."""
    checkpoint_dir = "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe"
    os.makedirs(checkpoint_dir, exist_ok=True)

    current_step = self.global_step
    # Clear Lightning's payload — the full state_dict is written separately below.
    checkpoint.clear()
    t = time.strftime("%Y%m%d-%H%M%S")  # e.g. 20250923-153047

    state_dict = self.pipe.denoising_model().state_dict()
    target = os.path.join(checkpoint_dir, f"step{current_step}_nus_moe_from_{t}.ckpt")
    torch.save(state_dict, target)
    print(f"Saved MoE model checkpoint: step{current_step}.ckpt")
def train_multi_dataset(args):
    """Train the multi-dataset MoE model: build the mixed dataset, the
    Lightning module, and run the Trainer."""

    # Dataset mix. `weight` is a relative sampling weight, normalized inside the
    # dataset; nuscenes is heavily oversampled (10.0 vs 0.1) here.
    dataset_configs = [
        {
            'name': 'sekai-drone',
            'paths': ['/share_zhuyixuan05/zhuyixuan05/sekai-game-drone'],
            'type': 'sekai',
            'weight': 0.1
        },
        {
            'name': 'sekai-walking',
            'paths': ['/share_zhuyixuan05/zhuyixuan05/sekai-game-walking'],
            'type': 'sekai',
            'weight': 0.1
        },
        {
            'name': 'spatialvid',
            'paths': ['/share_zhuyixuan05/zhuyixuan05/spatialvid'],
            'type': 'spatialvid',
            'weight': 0.1
        },
        {
            'name': 'nuscenes',
            'paths': ['/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic'],
            'type': 'nuscenes',
            'weight': 10.0
        },
        {
            'name': 'openx-fractal',
            'paths': ['/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded'],
            'type': 'openx',
            'weight': 0.1
        }
    ]

    dataset = MultiDatasetDynamicDataset(
        dataset_configs,
        steps_per_epoch=args.steps_per_epoch,
        min_condition_frames=args.min_condition_frames,
        max_condition_frames=args.max_condition_frames,
        target_frames=args.target_frames,
    )

    # batch_size must stay 1: training_step indexes collated fields with [0].
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    # MoE configuration consumed by MultiDatasetLightningModelForTrain.
    moe_config = {
        "unified_dim": args.unified_dim,
        "num_experts": args.moe_num_experts,
        "top_k": args.moe_top_k,
        "moe_loss_weight": args.moe_loss_weight,
        "sekai_input_dim": 13,
        "nuscenes_input_dim": 8,
        "openx_input_dim": 13
    }

    model = MultiDatasetLightningModelForTrain(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        resume_ckpt_path=args.resume_ckpt_path,
        use_moe=True,  # MoE is always enabled for this entry point
        moe_config=moe_config
    )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[],
        logger=False
    )
    trainer.fit(model, dataloader)
if __name__ == '__main__':
    def _cli_bool(value):
        """Parse CLI boolean values robustly.

        Fixes a bug: `--use_gradient_checkpointing False` previously produced the
        *string* "False", which is truthy downstream. Accepting a value keeps the
        old CLI shape (backward compatible) while parsing it correctly.
        """
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() in ("1", "true", "yes", "y")

    parser = argparse.ArgumentParser(description="Train Multi-Dataset FramePack with MoE")
    parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
    parser.add_argument("--output_path", type=str, default="./")
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--steps_per_epoch", type=int, default=8000)
    parser.add_argument("--max_epochs", type=int, default=100000)
    parser.add_argument("--min_condition_frames", type=int, default=8, help="最小条件帧数")
    parser.add_argument("--max_condition_frames", type=int, default=120, help="最大条件帧数")
    parser.add_argument("--target_frames", type=int, default=32, help="目标帧数")
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1")
    # BUGFIX: parse the value as a real boolean instead of a truthy string.
    parser.add_argument("--use_gradient_checkpointing", type=_cli_bool, default=False)
    parser.add_argument("--use_gradient_checkpointing_offload", action="store_true")
    parser.add_argument("--resume_ckpt_path", type=str, default="/home/zhuyixuan05/ReCamMaster/nus_dynamic/step15000_dynamic.ckpt")

    # MoE hyper-parameters.
    parser.add_argument("--unified_dim", type=int, default=25, help="统一的中间维度")
    parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_loss_weight", type=float, default=0.1, help="MoE损失权重")

    args = parser.parse_args()

    print("🔧 多数据集MoE训练配置:")
    print(f"   - 使用wan_video_dit_moe.py作为模型")
    print(f"   - 统一维度: {args.unified_dim}")
    print(f"   - 专家数量: {args.moe_num_experts}")
    print(f"   - Top-K: {args.moe_top_k}")
    print(f"   - MoE损失权重: {args.moe_loss_weight}")
    print("   - 数据集:")
    print("     - sekai-game-drone (sekai模态)")
    print("     - sekai-game-walking (sekai模态)")
    print("     - spatialvid (使用sekai模态处理器)")
    # BUGFIX: the banner claimed openx reuses the sekai processor, but
    # training_step routes openx to its own "openx" modality processor.
    print("     - openx-fractal (openx模态)")
    print(f"     - nuscenes (nuscenes模态)")

    train_multi_dataset(args)
def get_traj_position_change(cam_c2w, stride=1):
    """Per-step camera displacement vectors and turning angles (degrees).

    For each index i, compares the displacement over [i, i+stride] with the
    displacement over [i+stride, i+2*stride]; near-zero steps are skipped.
    Returns (list of displacement vectors, list of angles in degrees).
    """
    positions = cam_c2w[:, :3, 3]

    step_vectors = []
    turn_angles = []
    for i in range(len(positions) - 2 * stride):
        v1 = positions[i + stride] - positions[i]
        v2 = positions[i + 2 * stride] - positions[i + stride]

        n1 = np.linalg.norm(v1)
        n2 = np.linalg.norm(v2)
        # Skip (nearly) stationary steps — the angle would be undefined.
        if n1 < 1e-6 or n2 < 1e-6:
            continue

        cosine = np.dot(v1, v2) / (n1 * n2)
        step_vectors.append(v1)
        turn_angles.append(np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0))))

    return step_vectors, turn_angles

def get_traj_rotation_change(cam_c2w, stride=1):
    """Per-step rotation change (degrees) between camera viewing (z) axes.

    Returns the angle between the z-axis of rotation i and rotation i+stride;
    degenerate (near-zero) axes are skipped.
    """
    rotations = cam_c2w[:, :3, :3]

    rot_angles = []
    for i in range(len(rotations) - stride):
        z1 = rotations[i][:, 2]
        z2 = rotations[i + stride][:, 2]

        n1 = np.linalg.norm(z1)
        n2 = np.linalg.norm(z2)
        if n1 < 1e-6 or n2 < 1e-6:
            continue

        cosine = np.dot(z1, z2) / (n1 * n2)
        rot_angles.append(np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0))))

    return rot_angles
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Relative pose of camera B w.r.t. camera A: pose_b @ inverse(pose_a).

    Both inputs must be 4x4 homogeneous matrices. With use_torch=True the math
    runs (and the result is returned) in torch; otherwise in numpy.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
def compute_relative_pose_matrix(pose1, pose2):
    """Relative pose between two consecutive 7-DoF poses as a 3x4 [R_rel | t_rel].

    Args:
        pose1: frame-i pose, shape (7,): [tx, ty, tz, qx, qy, qz, qw].
        pose2: frame-(i+1) pose, same layout (xyzw quaternion, SciPy convention).

    Returns:
        3x4 matrix whose left 3x3 is R_rel = R2 @ R1^-1 and whose last column is
        t_rel = R1^T @ (t2 - t1), i.e. the translation expressed in frame i.
    """
    t1, q1 = pose1[:3], pose1[3:]
    t2, q2 = pose2[:3], pose2[3:]

    rot1 = R.from_quat(q1)
    rot2 = R.from_quat(q2)

    # Relative rotation: later frame composed with the inverse of the earlier one.
    R_rel = (rot2 * rot1.inv()).as_matrix()
    # Relative translation in frame-i coordinates (R1^T == R1^-1 for rotations).
    t_rel = rot1.as_matrix().T @ (t2 - t1)

    return np.hstack([R_rel, t_rel.reshape(3, 1)])
def __init__(self, dataset_configs, steps_per_epoch,
             min_condition_frames=10, max_condition_frames=40,
             target_frames=10, height=900, width=1600):
    """Scan all configured datasets and build a unified, weighted scene index.

    Args:
        dataset_configs: list of dicts, each with:
            'name':   dataset name (for logging),
            'paths':  dataset root path or list of paths,
            'type':   one of 'sekai', 'spatialvid', 'openx', 'nuscenes',
            'weight': relative sampling weight (default 1.0).
        steps_per_epoch: nominal epoch length (samples are drawn randomly).
        min_condition_frames / max_condition_frames / target_frames: segment
            sizes in ORIGINAL (pre-VAE) frames; divided by the 4x temporal
            compression ratio when selecting latent segments.
        height, width: nominal frame size (stored; not used for scanning).
    """
    self.dataset_configs = dataset_configs
    self.min_condition_frames = min_condition_frames
    self.max_condition_frames = max_condition_frames
    self.target_frames = target_frames
    self.height = height
    self.width = width
    self.steps_per_epoch = steps_per_epoch
    self.pose_classifier = PoseClassifier()

    # VAE temporal compression ratio (original frames per latent frame).
    self.time_compression_ratio = 4

    # Unified scene index across all datasets.
    self.scene_dirs = []        # every scene directory found
    self.dataset_info = {}      # scene_dir -> {name, type, weight}
    self.dataset_weights = []   # per-scene sampling weight (parallel to scene_dirs)

    total_scenes = 0

    for config in self.dataset_configs:
        dataset_name = config['name']
        dataset_paths = config['paths'] if isinstance(config['paths'], list) else [config['paths']]
        dataset_type = config['type']
        dataset_weight = config.get('weight', 1.0)

        print(f"🔧 扫描数据集: {dataset_name} (类型: {dataset_type})")

        dataset_scenes = []
        for dataset_path in dataset_paths:
            print(f"   📁 检查路径: {dataset_path}")
            if os.path.exists(dataset_path):
                if dataset_type == 'nuscenes':
                    # NuScenes layout: <base_path>/scenes/<scene_dir>.
                    scenes_path = os.path.join(dataset_path, "scenes")
                    print(f"   📂 扫描NuScenes scenes目录: {scenes_path}")
                    for item in os.listdir(scenes_path):
                        scene_dir = os.path.join(scenes_path, item)
                        if os.path.isdir(scene_dir):
                            self.scene_dirs.append(scene_dir)
                            dataset_scenes.append(scene_dir)
                            self.dataset_info[scene_dir] = {
                                'name': dataset_name,
                                'type': dataset_type,
                                'weight': dataset_weight
                            }
                            self.dataset_weights.append(dataset_weight)

                elif dataset_type in ['sekai', 'spatialvid', 'openx']:
                    # Flat layout: each subdir with an encoded_video.pth is a scene.
                    for item in os.listdir(dataset_path):
                        scene_dir = os.path.join(dataset_path, item)
                        if os.path.isdir(scene_dir):
                            encoded_path = os.path.join(scene_dir, "encoded_video.pth")
                            if os.path.exists(encoded_path):
                                self.scene_dirs.append(scene_dir)
                                dataset_scenes.append(scene_dir)
                                self.dataset_info[scene_dir] = {
                                    'name': dataset_name,
                                    'type': dataset_type,
                                    'weight': dataset_weight
                                }
                                self.dataset_weights.append(dataset_weight)
            else:
                print(f"   ❌ 路径不存在: {dataset_path}")

        print(f"   ✅ 找到 {len(dataset_scenes)} 个场景")
        total_scenes += len(dataset_scenes)

    # Per-dataset scene counts (logging only).
    dataset_counts = {}
    for scene_dir in self.scene_dirs:
        dataset_name = self.dataset_info[scene_dir]['name']
        dataset_type = self.dataset_info[scene_dir]['type']
        key = f"{dataset_name} ({dataset_type})"
        dataset_counts[key] = dataset_counts.get(key, 0) + 1

    for dataset_key, count in dataset_counts.items():
        print(f"   - {dataset_key}: {count} 个场景")

    assert len(self.scene_dirs) > 0, "No encoded scenes found!"

    # Normalize per-scene weights into sampling probabilities.
    total_weight = sum(self.dataset_weights)
    self.sampling_probs = [w / total_weight for w in self.dataset_weights]
def select_dynamic_segment_nuscenes(self, scene_info):
    """NuScenes-specific FramePack segment selection.

    Picks a random (condition, target) window in COMPRESSED (latent) frame
    units and derives the 1x/2x/4x clean-context index sets, plus the
    reference/target keyframes needed for pose embedding. Returns None when no
    valid segment can be formed (caller retries with another scene).
    """
    keyframe_indices = scene_info['keyframe_indices']  # original-frame indices
    total_frames = scene_info['total_frames']          # original frame count

    if len(keyframe_indices) < 2:
        return None

    # Convert everything to compressed (latent) frame units.
    compressed_total_frames = total_frames // self.time_compression_ratio
    compressed_keyframe_indices = [idx // self.time_compression_ratio for idx in keyframe_indices]

    min_condition_compressed = self.min_condition_frames // self.time_compression_ratio
    max_condition_compressed = self.max_condition_frames // self.time_compression_ratio
    target_frames_compressed = self.target_frames // self.time_compression_ratio

    # FramePack-style condition-length sampling:
    # 15% single frame / 75% uniform in range / 10% equal to the target length.
    ratio = random.random()
    if ratio < 0.15:
        condition_frames_compressed = 1
    elif 0.15 <= ratio < 0.9:
        condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed)
    else:
        condition_frames_compressed = target_frames_compressed

    # Bail out if the scene is too short for condition + target.
    min_required_frames = condition_frames_compressed + target_frames_compressed
    if compressed_total_frames < min_required_frames:
        return None

    start_frame_compressed = random.randint(0, compressed_total_frames - min_required_frames - 1)
    condition_end_compressed = start_frame_compressed + condition_frames_compressed
    target_end_compressed = condition_end_compressed + target_frames_compressed

    # Frames to denoise: the target window.
    latent_indices = torch.arange(condition_end_compressed, target_end_compressed)

    # 1x clean frames: segment start + last condition frame.
    clean_latent_indices_start = torch.tensor([start_frame_compressed])
    clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1])
    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices])

    # 2x clean frames, depending on how much condition history exists.
    # NOTE(review): the extra -1 below can push the first index before
    # start_frame_compressed (and differs from the sekai variant) — confirm intended.
    if condition_frames_compressed >= 2:
        clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2)
        clean_latent_2x_indices = torch.arange(clean_latent_2x_start-1, condition_end_compressed-1)
    else:
        clean_latent_2x_indices = torch.tensor([], dtype=torch.long)

    # 4x clean frames: up to 16 frames of history.
    # NOTE(review): same concern — the -3 shift can index before the segment start.
    if condition_frames_compressed >= 1:
        clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16)
        clean_latent_4x_indices = torch.arange(clean_4x_start-3, condition_end_compressed-3)
    else:
        clean_latent_4x_indices = torch.tensor([], dtype=torch.long)

    # NuScenes-specific: locate keyframes inside the condition / target windows.
    condition_keyframes_compressed = [idx for idx in compressed_keyframe_indices
                                      if start_frame_compressed <= idx < condition_end_compressed]

    target_keyframes_compressed = [idx for idx in compressed_keyframe_indices
                                   if condition_end_compressed <= idx < target_end_compressed]

    if not condition_keyframes_compressed:
        return None

    # The last keyframe in the condition window becomes the pose reference.
    reference_keyframe_compressed = max(condition_keyframes_compressed)

    # Map the compressed reference keyframe back to its position in the
    # original keyframe list (used for pose lookup).
    reference_keyframe_original_idx = None
    for i, compressed_idx in enumerate(compressed_keyframe_indices):
        if compressed_idx == reference_keyframe_compressed:
            reference_keyframe_original_idx = i
            break

    if reference_keyframe_original_idx is None:
        return None

    # Same mapping for each keyframe in the target window.
    target_keyframes_original_indices = []
    for compressed_idx in target_keyframes_compressed:
        for i, comp_idx in enumerate(compressed_keyframe_indices):
            if comp_idx == compressed_idx:
                target_keyframes_original_indices.append(i)
                break

    # Original-frame index for every compressed frame in the segment.
    keyframe_original_idx = []
    for compressed_idx in range(start_frame_compressed, target_end_compressed):
        keyframe_original_idx.append(compressed_idx * 4)

    return {
        'start_frame': start_frame_compressed,
        'condition_frames': condition_frames_compressed,
        'target_frames': target_frames_compressed,
        'condition_range': (start_frame_compressed, condition_end_compressed),
        'target_range': (condition_end_compressed, target_end_compressed),

        # FramePack index sets.
        'latent_indices': latent_indices,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,

        'keyframe_original_idx': keyframe_original_idx,
        'original_condition_frames': condition_frames_compressed * self.time_compression_ratio,
        'original_target_frames': target_frames_compressed * self.time_compression_ratio,

        # NuScenes-specific pose lookup data.
        'reference_keyframe_idx': reference_keyframe_original_idx,
        'target_keyframe_indices': target_keyframes_original_indices,
    }
def calculate_relative_rotation(self, current_rotation, reference_rotation):
    """Relative rotation quaternion (NuScenes): q_ref^-1 * q_current.

    Quaternions are [w, x, y, z]; the reference is inverted via its conjugate
    (valid for unit quaternions — assumed, TODO confirm inputs are normalized),
    then combined with the current rotation via the Hamilton product.
    """
    q_current = torch.tensor(current_rotation, dtype=torch.float32)
    q_ref = torch.tensor(reference_rotation, dtype=torch.float32)

    # Conjugate of the reference quaternion (== inverse for unit quaternions).
    q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]])

    w1, x1, y1, z1 = q_ref_inv
    w2, x2, y2, z2 = q_current

    # Hamilton product q_ref_inv * q_current.
    relative_rotation = torch.tensor([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
    ])

    return relative_rotation


def prepare_framepack_inputs(self, full_latents, segment_info):
    """Build FramePack multi-scale inputs from the full latent sequence.

    Slices the main (to-denoise) latents and the 1x clean context, and packs
    the 2x / 4x contexts into FIXED-length buffers (2 and 16 frames) that are
    zero-padded on the left; padded slots are marked with index -1.
    """
    # Accept [C, T, H, W] by adding a batch dim (removed again before returning).
    if len(full_latents.shape) == 4:
        full_latents = full_latents.unsqueeze(0)  # [C, T, H, W] -> [1, C, T, H, W]
        B, C, T, H, W = full_latents.shape
    else:
        B, C, T, H, W = full_latents.shape

    # Main latents (denoising prediction targets) — indexed along the T axis.
    latent_indices = segment_info['latent_indices']
    main_latents = full_latents[:, :, latent_indices, :, :]

    # 1x clean frames (segment start + last condition frame).
    clean_latent_indices = segment_info['clean_latent_indices']
    clean_latents = full_latents[:, :, clean_latent_indices, :, :]

    # 4x context: always 16 slots, real frames right-aligned, rest zero.
    clean_latent_4x_indices = segment_info['clean_latent_4x_indices']

    clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype)
    clean_latent_4x_indices_final = torch.full((16,), -1, dtype=torch.long)  # -1 marks padding

    if len(clean_latent_4x_indices) > 0:
        actual_4x_frames = len(clean_latent_4x_indices)
        # Right-align so the most recent frames sit at the end of the buffer.
        start_pos = max(0, 16 - actual_4x_frames)
        end_pos = 16
        actual_start = max(0, actual_4x_frames - 16)  # keep only the last 16 frames

        clean_latents_4x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :]
        clean_latent_4x_indices_final[start_pos:end_pos] = clean_latent_4x_indices[actual_start:]

    # 2x context: always 2 slots, same right-aligned zero-padding scheme.
    clean_latent_2x_indices = segment_info['clean_latent_2x_indices']

    clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype)
    clean_latent_2x_indices_final = torch.full((2,), -1, dtype=torch.long)  # -1 marks padding

    if len(clean_latent_2x_indices) > 0:
        actual_2x_frames = len(clean_latent_2x_indices)
        start_pos = max(0, 2 - actual_2x_frames)
        end_pos = 2
        actual_start = max(0, actual_2x_frames - 2)  # keep only the last 2 frames

        clean_latents_2x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_2x_indices[actual_start:], :, :]
        clean_latent_2x_indices_final[start_pos:end_pos] = clean_latent_2x_indices[actual_start:]

    # Strip the batch dim we may have added, restoring the caller's format.
    if B == 1:
        main_latents = main_latents.squeeze(0)  # [1, C, T, H, W] -> [C, T, H, W]
        clean_latents = clean_latents.squeeze(0)
        clean_latents_2x = clean_latents_2x.squeeze(0)
        clean_latents_4x = clean_latents_4x.squeeze(0)

    return {
        'latents': main_latents,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'latent_indices': segment_info['latent_indices'],
        'clean_latent_indices': segment_info['clean_latent_indices'],
        'clean_latent_2x_indices': clean_latent_2x_indices_final,  # real indices with -1 padding
        'clean_latent_4x_indices': clean_latent_4x_indices_final,  # real indices with -1 padding
    }
def create_sekai_pose_embeddings(self, cam_data, segment_info):
    """Sekai-style pose embeddings: one flattened 3x4 relative pose (12-D,
    bfloat16) per latent frame, computed between frames 4 apart.

    NOTE(review): unlike the openx variant this does not bounds-check
    `fid + 4`; sequences must extend past the segment end.
    """
    extrinsics = cam_data['extrinsic']

    # Original-frame index for every latent frame in the segment (x4 spacing).
    frame_ids = [c_idx * 4
                 for c_idx in range(segment_info['start_frame'], segment_info['target_range'][1])]

    rel_poses = []
    for fid in frame_ids:
        rel = compute_relative_pose(extrinsics[fid], extrinsics[fid + 4])
        rel_poses.append(torch.as_tensor(rel[:3, :]))

    embedding = torch.stack(rel_poses, dim=0)
    embedding = rearrange(embedding, 'b c d -> b (c d)')
    return embedding.to(torch.bfloat16)

def create_openx_pose_embeddings(self, cam_data, segment_info):
    """OpenX variant of the Sekai embeddings: same 4-frame spacing, but frames
    without a +4 successor are dropped; returns None if nothing remains."""
    extrinsics = cam_data['extrinsic']

    # Keep only frame ids that still have a frame 4 steps ahead.
    frame_ids = []
    for c_idx in range(segment_info['start_frame'], segment_info['target_range'][1]):
        fid = c_idx * 4
        if fid + 4 < len(extrinsics):
            frame_ids.append(fid)

    rel_poses = []
    for fid in frame_ids:
        if fid + 4 < len(extrinsics):
            rel = compute_relative_pose(extrinsics[fid], extrinsics[fid + 4])
            rel_poses.append(torch.as_tensor(rel[:3, :]))
        else:
            # Unreachable given the filtering above; kept for parity with the original.
            rel_poses.append(torch.eye(3, 4))

    if len(rel_poses) == 0:
        return None

    embedding = torch.stack(rel_poses, dim=0)
    embedding = rearrange(embedding, 'b c d -> b (c d)')
    return embedding.to(torch.bfloat16)
def create_spatialvid_pose_embeddings(self, cam_data, segment_info):
    """SpatialVid-style pose embeddings: relative pose between ADJACENT frames
    (1-frame spacing instead of sekai's 4), flattened to 12-D bfloat16.

    NOTE(review): the original comment labels `extrinsic` as N*4*4, yet it is
    fed to compute_relative_pose_matrix, which expects 7-D [t, q] vectors —
    verify the stored format.
    """
    extrinsics = cam_data['extrinsic']

    # keyframe_original_idx already holds per-frame original indices for the segment.
    rel_poses = []
    for fid in segment_info['keyframe_original_idx']:
        if fid + 1 < len(extrinsics):
            rel = compute_relative_pose_matrix(extrinsics[fid], extrinsics[fid + 1])
            rel_poses.append(torch.as_tensor(rel[:3, :]))
        else:
            # No successor frame: fall back to zero motion.
            rel_poses.append(torch.zeros(3, 4))

    if len(rel_poses) == 0:
        return None

    embedding = torch.stack(rel_poses, dim=0)
    embedding = rearrange(embedding, 'b c d -> b (c d)')
    return embedding.to(torch.bfloat16)
def create_nuscenes_pose_embeddings_framepack(self, scene_info, segment_info):
    """NuScenes pose embeddings for FramePack: a 7-D vector
    [relative translation (3) + relative rotation quaternion (4)] per frame,
    all relative to the last condition keyframe. Returns [total_frames, 7]
    (float32) or None when the reference pose is unavailable.
    """
    keyframe_poses = scene_info['keyframe_poses']
    reference_keyframe_idx = segment_info['reference_keyframe_idx']
    target_keyframe_indices = segment_info['target_keyframe_indices']

    if reference_keyframe_idx >= len(keyframe_poses):
        return None

    reference_pose = keyframe_poses[reference_keyframe_idx]

    # Embeddings are produced for every frame (condition + target).
    start_frame = segment_info['start_frame']
    condition_end_compressed = start_frame + segment_info['condition_frames']
    target_end_compressed = condition_end_compressed + segment_info['target_frames']

    # Keyframe indices in compressed (latent) units.
    compressed_keyframe_indices = [idx // self.time_compression_ratio for idx in scene_info['keyframe_indices']]

    # Keyframes that fall inside the condition window.
    condition_keyframes_compressed = [idx for idx in compressed_keyframe_indices
                                      if start_frame <= idx < condition_end_compressed]

    # Map them back to positions in the original keyframe list.
    condition_keyframes_original_indices = []
    for compressed_idx in condition_keyframes_compressed:
        for i, comp_idx in enumerate(compressed_keyframe_indices):
            if comp_idx == compressed_idx:
                condition_keyframes_original_indices.append(i)
                break

    pose_vecs = []

    # Condition frames: linearly interpolate over the available keyframes.
    for i in range(segment_info['condition_frames']):
        if not condition_keyframes_original_indices:
            # No keyframes: identity pose (zero translation, unit quaternion).
            translation = torch.zeros(3, dtype=torch.float32)
            rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
        else:
            # Pick a keyframe for this condition frame.
            if len(condition_keyframes_original_indices) == 1:
                keyframe_idx = condition_keyframes_original_indices[0]
            else:
                if segment_info['condition_frames'] == 1:
                    keyframe_idx = condition_keyframes_original_indices[0]
                else:
                    # Map frame position i proportionally onto the keyframe list.
                    interp_ratio = i / (segment_info['condition_frames'] - 1)
                    interp_idx = int(interp_ratio * (len(condition_keyframes_original_indices) - 1))
                    keyframe_idx = condition_keyframes_original_indices[interp_idx]

            if keyframe_idx >= len(keyframe_poses):
                translation = torch.zeros(3, dtype=torch.float32)
                rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
            else:
                condition_pose = keyframe_poses[keyframe_idx]

                translation = torch.tensor(
                    np.array(condition_pose['translation']) - np.array(reference_pose['translation']),
                    dtype=torch.float32
                )

                relative_rotation = self.calculate_relative_rotation(
                    condition_pose['rotation'],
                    reference_pose['rotation']
                )

                rotation = relative_rotation

        # Direct 7-D vector: [translation(3) + rotation(4)].
        pose_vec = torch.cat([translation, rotation], dim=0)  # [7D]
        pose_vecs.append(pose_vec)

    # Target frames: identity poses if no target keyframes exist.
    if not target_keyframe_indices:
        for i in range(segment_info['target_frames']):
            pose_vec = torch.cat([
                torch.zeros(3, dtype=torch.float32),
                torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32),
            ], dim=0)  # [7D]
            pose_vecs.append(pose_vec)
    else:
        for i in range(segment_info['target_frames']):
            # Same proportional keyframe selection as the condition branch.
            if len(target_keyframe_indices) == 1:
                target_keyframe_idx = target_keyframe_indices[0]
            else:
                if segment_info['target_frames'] == 1:
                    target_keyframe_idx = target_keyframe_indices[0]
                else:
                    interp_ratio = i / (segment_info['target_frames'] - 1)
                    interp_idx = int(interp_ratio * (len(target_keyframe_indices) - 1))
                    target_keyframe_idx = target_keyframe_indices[interp_idx]

            if target_keyframe_idx >= len(keyframe_poses):
                pose_vec = torch.cat([
                    torch.zeros(3, dtype=torch.float32),
                    torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32),
                ], dim=0)  # [7D]
            else:
                target_pose = keyframe_poses[target_keyframe_idx]

                relative_translation = torch.tensor(
                    np.array(target_pose['translation']) - np.array(reference_pose['translation']),
                    dtype=torch.float32
                )

                relative_rotation = self.calculate_relative_rotation(
                    target_pose['rotation'],
                    reference_pose['rotation']
                )

                # Direct 7-D vector: [translation(3) + rotation(4)].
                pose_vec = torch.cat([relative_translation, relative_rotation], dim=0)  # [7D]

            pose_vecs.append(pose_vec)

    if not pose_vecs:
        return None

    pose_sequence = torch.stack(pose_vecs, dim=0)  # [total_frames, 7]

    return pose_sequence
def select_dynamic_segment(self, full_latents, dataset_type, scene_info=None):
    """Select a (condition, target) segment using the strategy for `dataset_type`.

    NuScenes (with scene_info) delegates to the keyframe-aware selector; all
    other datasets use the generic FramePack-style random window below.
    Returns a segment-info dict, or None when the sequence is too short.
    """
    if dataset_type == 'nuscenes' and scene_info is not None:
        return self.select_dynamic_segment_nuscenes(scene_info)
    else:
        total_lens = full_latents.shape[1]

        # Convert frame limits to compressed (latent) units and clamp the max
        # condition length so condition + target still fits in the sequence.
        min_condition_compressed = self.min_condition_frames // self.time_compression_ratio
        max_condition_compressed = self.max_condition_frames // self.time_compression_ratio
        target_frames_compressed = self.target_frames // self.time_compression_ratio
        max_condition_compressed = min(total_lens-target_frames_compressed-1, max_condition_compressed)

        # 15% single frame / 75% uniform / 10% target-length condition.
        ratio = random.random()
        if ratio < 0.15:
            condition_frames_compressed = 1
        elif 0.15 <= ratio < 0.9 or total_lens <= 2*target_frames_compressed + 1:
            # NOTE(review): randint raises if min > max on very short clips;
            # callers filter out sequences with <= 10 latent frames upstream.
            condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed)
        else:
            condition_frames_compressed = target_frames_compressed

        # Ensure the sequence can hold the whole segment.
        min_required_frames = condition_frames_compressed + target_frames_compressed
        if total_lens < min_required_frames:
            return None

        start_frame_compressed = random.randint(0, total_lens - min_required_frames - 1)
        condition_end_compressed = start_frame_compressed + condition_frames_compressed
        target_end_compressed = condition_end_compressed + target_frames_compressed

        # Frames to denoise: the target window.
        latent_indices = torch.arange(condition_end_compressed, target_end_compressed)

        # 1x clean frames: segment start + last condition frame.
        clean_latent_indices_start = torch.tensor([start_frame_compressed])
        clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1])
        clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices])

        # 2x clean frames: up to the last 2 condition frames.
        if condition_frames_compressed >= 2:
            clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2-1)
            clean_latent_2x_indices = torch.arange(clean_latent_2x_start, condition_end_compressed-1)
        else:
            clean_latent_2x_indices = torch.tensor([], dtype=torch.long)

        # 4x clean frames: up to 16 frames of history.
        if condition_frames_compressed > 3:
            clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16-3)
            clean_latent_4x_indices = torch.arange(clean_4x_start, condition_end_compressed-3)
        else:
            clean_latent_4x_indices = torch.tensor([], dtype=torch.long)

        # Original-frame index per compressed frame. BUGFIX: the original used
        # `elif dataset_type == 'openx' or 'sekai':`, which is always true
        # because the bare string 'sekai' is truthy. Use a membership test, and
        # keep a fallback branch so unknown types still get the old *4 mapping.
        keyframe_original_idx = []
        for compressed_idx in range(start_frame_compressed, target_end_compressed):
            if dataset_type == 'spatialvid':
                keyframe_original_idx.append(compressed_idx)  # spatialvid: 1-frame spacing
            elif dataset_type in ('openx', 'sekai'):
                keyframe_original_idx.append(compressed_idx * 4)  # 4x temporal compression
            else:
                keyframe_original_idx.append(compressed_idx * 4)  # legacy fallback

        return {
            'start_frame': start_frame_compressed,
            'condition_frames': condition_frames_compressed,
            'target_frames': target_frames_compressed,
            'condition_range': (start_frame_compressed, condition_end_compressed),
            'target_range': (condition_end_compressed, target_end_compressed),

            # FramePack index sets.
            'latent_indices': latent_indices,
            'clean_latent_indices': clean_latent_indices,
            'clean_latent_2x_indices': clean_latent_2x_indices,
            'clean_latent_4x_indices': clean_latent_4x_indices,

            'keyframe_original_idx': keyframe_original_idx,
            'original_condition_frames': condition_frames_compressed * self.time_compression_ratio,
            'original_target_frames': target_frames_compressed * self.time_compression_ratio,
        }
keyframe_original_idx.append(compressed_idx * 4) # openx使用4倍间隔 + + return { + 'start_frame': start_frame_compressed, + 'condition_frames': condition_frames_compressed, + 'target_frames': target_frames_compressed, + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + + # FramePack风格的索引 + 'latent_indices': latent_indices, + 'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + + 'keyframe_original_idx': keyframe_original_idx, + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + + # 修改create_pose_embeddings方法 + def create_pose_embeddings(self, cam_data, segment_info, dataset_type, scene_info=None): + """🔧 根据数据集类型创建pose embeddings""" + if dataset_type == 'nuscenes' and scene_info is not None: + return self.create_nuscenes_pose_embeddings_framepack(scene_info, segment_info) + elif dataset_type == 'spatialvid': # 🔧 新增spatialvid处理 + return self.create_spatialvid_pose_embeddings(cam_data, segment_info) + elif dataset_type == 'sekai': + return self.create_sekai_pose_embeddings(cam_data, segment_info) + elif dataset_type == 'openx': # 🔧 新增openx处理 + return self.create_openx_pose_embeddings(cam_data, segment_info) + + def __getitem__(self, index): + while True: + try: + # 根据权重随机选择场景 + scene_idx = np.random.choice(len(self.scene_dirs), p=self.sampling_probs) + scene_dir = self.scene_dirs[scene_idx] + dataset_info = self.dataset_info[scene_dir] + + dataset_name = dataset_info['name'] + dataset_type = dataset_info['type'] + + # 🔧 根据数据集类型加载数据 + scene_info = None + if dataset_type == 'nuscenes': + # NuScenes需要加载scene_info.json + scene_info_path = os.path.join(scene_dir, "scene_info.json") + if os.path.exists(scene_info_path): + with open(scene_info_path, 'r') as f: + 
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' loader class for the MoE variant.

    Must run before any model is loaded: it rewrites the global
    ``model_loader_configs`` registry in place so that subsequent
    ``ModelManager.load_models`` calls instantiate ``WanModelMoe``.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        # Only registry entries that contain the DiT model need patching.
        if 'wan_video_dit' not in names:
            continue

        patched_names, patched_classes = [], []
        for name, cls in zip(names, classes):
            if name == 'wan_video_dit':
                patched_names.append(name)
                patched_classes.append(WanModelMoe)  # substitute the MoE class
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_names.append(name)
                patched_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape,
                                     patched_names, patched_classes, resource)
def __init__(
    self,
    dit_path,
    learning_rate=1e-5,
    use_gradient_checkpointing=True,
    use_gradient_checkpointing_offload=False,
    resume_ckpt_path=None,
    # MoE options
    use_moe=False,
    moe_config=None,
):
    """Lightning wrapper that builds the Wan DiT pipeline with MoE camera control.

    Args:
        dit_path: path to the DiT checkpoint, or a comma-separated list of paths.
        learning_rate: AdamW learning rate for the trainable submodules.
        use_gradient_checkpointing(_offload): forwarded to the denoising model.
        resume_ckpt_path: optional checkpoint to (non-strictly) load into the DiT.
        use_moe: whether to attach MoE components to each transformer block.
        moe_config: dict of MoE hyper-parameters (see train_multi_dataset).
    """
    super().__init__()
    self.use_moe = use_moe
    self.moe_config = moe_config or {}

    # Patch the loader registry before instantiating any model.
    replace_dit_model_in_manager()
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    if os.path.isfile(dit_path):
        model_manager.load_models([dit_path])
    else:
        # BUGFIX: the original called load_models([dit_path]) after
        # dit_path = dit_path.split(","), handing ModelManager a nested
        # list instead of a flat list of paths.
        dit_path = dit_path.split(",")
        model_manager.load_models(dit_path)
    model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"])

    self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
    self.pipe.scheduler.set_timesteps(1000, training=True)

    # FramePack multi-scale embedder + optional MoE components.
    self.add_framepack_components()
    if self.use_moe:
        self.add_moe_components()

    # Per-block camera encoder (the MoE routing itself lives in wan_video_dit_moe.py).
    dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in self.pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        # Zero-init the encoder and identity-init the projector so the
        # untrained camera branch starts out as a no-op.
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    if resume_ckpt_path is not None:
        state_dict = torch.load(resume_ckpt_path, map_location="cpu")
        self.pipe.dit.load_state_dict(state_dict, strict=False)
        print('load checkpoint:', resume_ckpt_path)

    self.freeze_parameters()

    # Re-enable gradients only for the camera / MoE / FramePack submodules.
    for name, module in self.pipe.denoising_model().named_modules():
        if any(keyword in name for keyword in ["projector", "self_attn", "clean_x_embedder",
                                               "moe", "sekai_processor"]):
            for param in module.parameters():
                param.requires_grad = True

    self.learning_rate = learning_rate
    self.use_gradient_checkpointing = use_gradient_checkpointing
    self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

    # Directory for optional training visualizations.
    self.vis_dir = "multi_dataset_dynamic/visualizations"
    os.makedirs(self.vis_dir, exist_ok=True)

def add_moe_components(self):
    """Attach per-block modality processors and MoE networks to the DiT."""
    if not hasattr(self.pipe.dit, 'moe_config'):
        self.pipe.dit.moe_config = self.moe_config
        print("✅ 添加了MoE配置到模型")

    dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = self.moe_config.get("unified_dim", 30)

    for i, block in enumerate(self.pipe.dit.blocks):
        from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

        # Sekai modality processor: 13-D pose -> unified_dim.
        block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)

        # MoE head: unified_dim -> transformer hidden dim.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=self.moe_config.get("num_experts", 4),
            top_k=self.moe_config.get("top_k", 2),
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {self.moe_config.get('num_experts', 4)})")
def add_framepack_components(self):
    """Attach FramePack's multi-scale clean-latent embedder to the DiT."""
    if not hasattr(self.pipe.dit, 'clean_x_embedder'):
        inner_dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]

        class CleanXEmbedder(nn.Module):
            """Conv3d patchifiers for the 1x / 2x / 4x temporal scales."""

            def __init__(self, inner_dim):
                super().__init__()
                self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
                self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
                self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

            def forward(self, x, scale="1x"):
                if scale == "1x":
                    return self.proj(x)
                elif scale == "2x":
                    return self.proj_2x(x)
                elif scale == "4x":
                    return self.proj_4x(x)
                else:
                    raise ValueError(f"Unsupported scale: {scale}")

        self.pipe.dit.clean_x_embedder = CleanXEmbedder(inner_dim)
        print("✅ 添加了FramePack的clean_x_embedder组件")

def freeze_parameters(self):
    """Freeze the whole pipeline; trainable submodules are re-enabled in __init__."""
    self.pipe.requires_grad_(False)
    self.pipe.eval()
    self.pipe.denoising_model().train()

def training_step(self, batch, batch_idx):
    """Multi-dataset training step (the dataloader fixes batch_size to 1)."""
    condition_frames = batch["condition_frames"][0].item()
    target_frames = batch["target_frames"][0].item()

    dataset_name = batch.get("dataset_name", ["unknown"])[0]
    dataset_type = batch.get("dataset_type", ["sekai"])[0]

    # --- unpack latents (restore the batch dim squeezed by the dataset) ---
    latents = batch["latents"].to(self.device)
    if len(latents.shape) == 4:
        latents = latents.unsqueeze(0)

    clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None
    if clean_latents is not None and len(clean_latents.shape) == 4:
        clean_latents = clean_latents.unsqueeze(0)

    clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None
    if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4:
        clean_latents_2x = clean_latents_2x.unsqueeze(0)

    clean_latents_4x = batch["clean_latents_4x"].to(self.device) if batch["clean_latents_4x"].numel() > 0 else None
    if clean_latents_4x is not None and len(clean_latents_4x.shape) == 4:
        clean_latents_4x = clean_latents_4x.unsqueeze(0)

    # --- index tensors ---
    latent_indices = batch["latent_indices"].to(self.device)
    clean_latent_indices = batch["clean_latent_indices"].to(self.device) if batch["clean_latent_indices"].numel() > 0 else None
    clean_latent_2x_indices = batch["clean_latent_2x_indices"].to(self.device) if batch["clean_latent_2x_indices"].numel() > 0 else None
    clean_latent_4x_indices = batch["clean_latent_4x_indices"].to(self.device) if batch["clean_latent_4x_indices"].numel() > 0 else None

    # --- camera embedding & per-modality routing ---
    cam_emb = batch["camera"].to(self.device)

    if dataset_type in ("sekai", "spatialvid"):
        # spatialvid reuses the sekai processor, hence the "sekai" key.
        modality_inputs = {"sekai": cam_emb}
    elif dataset_type == "nuscenes":
        modality_inputs = {"nuscenes": cam_emb}
    elif dataset_type == "openx":
        modality_inputs = {"openx": cam_emb}
    else:
        modality_inputs = {"sekai": cam_emb}  # fallback

    # Random camera dropout for classifier-free-guidance training.
    camera_dropout_prob = 0.05
    if random.random() < camera_dropout_prob:
        cam_emb = torch.zeros_like(cam_emb)
        for key in modality_inputs:
            modality_inputs[key] = torch.zeros_like(modality_inputs[key])
        print(f"应用camera dropout for CFG training (dataset: {dataset_name}, type: {dataset_type})")

    prompt_emb = batch["prompt_emb"]
    prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
    image_emb = batch["image_emb"]
    if "clip_feature" in image_emb:
        image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
    if "y" in image_emb:
        image_emb["y"] = image_emb["y"][0].to(self.device)

    # --- diffusion targets ---
    self.pipe.device = self.device
    noise = torch.randn_like(latents)
    timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
    timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)

    # FramePack-style condition corruption: with p=0.8 add noise to the clean
    # condition latents.  `.clone()` replaces the original copy.deepcopy —
    # cheaper and equivalent for tensors.
    noisy_condition_latents = None
    if clean_latents is not None:
        noisy_condition_latents = clean_latents.clone()
        if random.random() > 0.2:
            noise_cond = torch.randn_like(clean_latents)
            timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps // 4 * 3, (1,))
            timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
            noisy_condition_latents = self.pipe.scheduler.add_noise(clean_latents, noise_cond, timestep_cond)

    extra_input = self.pipe.prepare_extra_input(latents)
    noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
    training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

    noise_pred, moe_loss = self.pipe.denoising_model()(
        noisy_latents,
        timestep=timestep,
        cam_emb=cam_emb,
        modality_inputs=modality_inputs,
        latent_indices=latent_indices,
        clean_latents=noisy_condition_latents if noisy_condition_latents is not None else clean_latents,
        clean_latent_indices=clean_latent_indices,
        clean_latents_2x=clean_latents_2x,
        clean_latent_2x_indices=clean_latent_2x_indices,
        clean_latents_4x=clean_latents_4x,
        clean_latent_4x_indices=clean_latent_4x_indices,
        **prompt_emb,
        **extra_input,
        **image_emb,
        use_gradient_checkpointing=self.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload,
    )

    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    loss = loss * self.pipe.scheduler.training_weight(timestep)

    # BUGFIX: the model returns an auxiliary MoE balancing loss that was
    # previously discarded, leaving the moe_loss_weight config knob dead.
    moe_loss_weight = self.moe_config.get("moe_loss_weight", 0.0)
    if moe_loss is not None and moe_loss_weight > 0:
        loss = loss + moe_loss_weight * moe_loss

    print(f'--------loss ({dataset_name}-{dataset_type})------------:', loss)
    return loss

def configure_optimizers(self):
    """AdamW over only the parameters unfrozen in __init__."""
    trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
    return torch.optim.AdamW(trainable_modules, lr=self.learning_rate)

def on_save_checkpoint(self, checkpoint):
    """Save only the denoising model's weights; strip Lightning's own state."""
    checkpoint_dir = "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_spatialvid"
    os.makedirs(checkpoint_dir, exist_ok=True)

    current_step = self.global_step
    # Empty Lightning's checkpoint dict so the (huge) default state is not written.
    checkpoint.clear()

    state_dict = self.pipe.denoising_model().state_dict()
    torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}.ckpt"))
    print(f"Saved MoE model checkpoint: step{current_step}.ckpt")
def train_multi_dataset(args):
    """Build the dataset / model / trainer and launch multi-dataset MoE training."""

    # Dataset roster: append entries here to mix in more sources
    # (supported types: 'sekai', 'spatialvid', 'nuscenes', 'openx').
    dataset_configs = [
        {
            'name': 'spatialvid',
            'paths': ['/share_zhuyixuan05/zhuyixuan05/spatialvid'],
            'type': 'spatialvid',
            'weight': 1.0
        },
    ]

    dataset = MultiDatasetDynamicDataset(
        dataset_configs,
        steps_per_epoch=args.steps_per_epoch,
        min_condition_frames=args.min_condition_frames,
        max_condition_frames=args.max_condition_frames,
        target_frames=args.target_frames,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    # MoE hyper-parameters forwarded to the Lightning module.
    moe_config = {
        "unified_dim": args.unified_dim,
        "num_experts": args.moe_num_experts,
        "top_k": args.moe_top_k,
        "moe_loss_weight": args.moe_loss_weight,
        "sekai_input_dim": 13,
        "nuscenes_input_dim": 8,
        "openx_input_dim": 13
    }

    model = MultiDatasetLightningModelForTrain(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        resume_ckpt_path=args.resume_ckpt_path,
        use_moe=True,  # MoE is always enabled for this script
        moe_config=moe_config
    )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[],
        logger=False
    )
    trainer.fit(model, dataloader)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train Multi-Dataset FramePack with MoE")
    parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
    parser.add_argument("--output_path", type=str, default="./")
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--steps_per_epoch", type=int, default=2000)
    parser.add_argument("--max_epochs", type=int, default=100000)
    parser.add_argument("--min_condition_frames", type=int, default=8, help="最小条件帧数")
    parser.add_argument("--max_condition_frames", type=int, default=120, help="最大条件帧数")
    parser.add_argument("--target_frames", type=int, default=32, help="目标帧数")
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1")
    # BUGFIX: this flag used `default=False` with no action, so any supplied
    # value (even the string "False") parsed as truthy.  Make it a proper
    # boolean flag, consistent with --use_gradient_checkpointing_offload.
    parser.add_argument("--use_gradient_checkpointing", action="store_true")
    parser.add_argument("--use_gradient_checkpointing_offload", action="store_true")
    parser.add_argument("--resume_ckpt_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack/step1000_framepack.ckpt")

    # MoE options
    parser.add_argument("--unified_dim", type=int, default=25, help="统一的中间维度")
    parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_loss_weight", type=float, default=0.00, help="MoE损失权重")

    args = parser.parse_args()

    print("🔧 多数据集MoE训练配置:")
    print(f"  - 使用wan_video_dit_moe.py作为模型")
    print(f"  - 统一维度: {args.unified_dim}")
    print(f"  - 专家数量: {args.moe_num_experts}")
    print(f"  - Top-K: {args.moe_top_k}")
    print(f"  - MoE损失权重: {args.moe_loss_weight}")
    # BUGFIX(consistency): the startup banner listed datasets (sekai,
    # nuscenes, openx) that are not in the active dataset_configs above.
    print("  - 数据集:")
    print("    - spatialvid (使用sekai模态处理器)")

    train_multi_dataset(args)
def get_traj_position_change(cam_c2w, stride=1):
    """Per-step displacement vectors and turning angles of a camera trajectory.

    Args:
        cam_c2w: [N, 4, 4] array of camera-to-world matrices.
        stride: frame step used for finite differences.

    Returns:
        (traj_coord, traj_angle): displacement vectors v1 and the angle in
        degrees between consecutive displacements; near-zero steps are skipped.
    """
    positions = cam_c2w[:, :3, 3]

    traj_coord, traj_angle = [], []
    for i in range(len(positions) - 2 * stride):
        v1 = positions[i + stride] - positions[i]
        v2 = positions[i + 2 * stride] - positions[i + stride]

        n1, n2 = np.linalg.norm(v1), np.linalg.norm(v2)
        # Skip degenerate (near-stationary) steps to avoid division by ~0.
        if min(n1, n2) < 1e-6:
            continue

        cosine = np.dot(v1, v2) / (n1 * n2)
        traj_coord.append(v1)
        traj_angle.append(np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0))))

    return traj_coord, traj_angle

def get_traj_rotation_change(cam_c2w, stride=1):
    """Angles (deg) between camera forward (z) axes `stride` frames apart."""
    rotations = cam_c2w[:, :3, :3]

    angles = []
    for i in range(len(rotations) - stride):
        z_now = rotations[i][:, 2]
        z_next = rotations[i + stride][:, 2]

        n1, n2 = np.linalg.norm(z_now), np.linalg.norm(z_next)
        if min(n1, n2) < 1e-6:
            continue

        cosine = np.dot(z_now, z_next) / (n1 * n2)
        angles.append(np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0))))

    return angles

def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Pose of camera B relative to camera A, i.e. pose_b @ inv(pose_a).

    Both inputs must be 4x4 extrinsic matrices; numpy in / numpy out unless
    use_torch is set, in which case float32 tensors are used throughout.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
def compute_relative_pose_matrix(pose1, pose2):
    """Relative pose between two consecutive 7-DoF camera poses.

    Args:
        pose1: frame-i pose, shape (7,): [tx, ty, tz, qx, qy, qz, qw].
        pose2: frame-(i+1) pose, same layout.

    Returns:
        3x4 matrix [R_rel | t_rel] expressing frame i+1 in frame i's coordinates.
    """
    # Split translation / quaternion components.
    t1, q1 = pose1[:3], pose1[3:]
    t2, q2 = pose2[:3], pose2[3:]

    rot1 = R.from_quat(q1)
    rot2 = R.from_quat(q2)

    # Relative rotation: next rotation composed with the inverse of the current one.
    R_rel = (rot2 * rot1.inv()).as_matrix()

    # Relative translation expressed in frame i's coordinates: R1^T @ (t2 - t1).
    t_rel = rot1.as_matrix().T @ (t2 - t1)

    return np.hstack([R_rel, t_rel.reshape(3, 1)])
def __init__(self, dataset_configs, steps_per_epoch,
             min_condition_frames=10, max_condition_frames=40,
             target_frames=10, height=900, width=1600):
    """
    Args:
        dataset_configs: list of dataset configs, each a dict {
            'name': dataset name,
            'paths': dataset root path(s),
            'type': dataset type ('sekai', 'spatialvid', 'openx' or 'nuscenes'),
            'weight': sampling weight
        }
        steps_per_epoch: virtual epoch length (samples are drawn randomly).
        min/max_condition_frames, target_frames: segment sizes in original frames.
        height, width: nominal video resolution.
    """
    self.dataset_configs = dataset_configs
    self.min_condition_frames = min_condition_frames
    self.max_condition_frames = max_condition_frames
    self.target_frames = target_frames
    self.height = height
    self.width = width
    self.steps_per_epoch = steps_per_epoch
    self.pose_classifier = PoseClassifier()

    # VAE temporal compression ratio (original frames per latent frame).
    self.time_compression_ratio = 4

    # Unified scene index across all configured datasets.
    self.scene_dirs = []
    self.dataset_info = {}      # scene_dir -> {'name', 'type', 'weight'}
    self.dataset_weights = []   # per-scene sampling weight

    total_scenes = 0

    for config in self.dataset_configs:
        dataset_name = config['name']
        dataset_paths = config['paths'] if isinstance(config['paths'], list) else [config['paths']]
        dataset_type = config['type']
        dataset_weight = config.get('weight', 1.0)

        print(f"🔧 扫描数据集: {dataset_name} (类型: {dataset_type})")

        dataset_scenes = []
        for dataset_path in dataset_paths:
            print(f"  📁 检查路径: {dataset_path}")
            if not os.path.exists(dataset_path):
                print(f"  ❌ 路径不存在: {dataset_path}")
                continue

            if dataset_type == 'nuscenes':
                # NuScenes nests scene folders under <root>/scenes.
                scenes_path = os.path.join(dataset_path, "scenes")
                print(f"  📂 扫描NuScenes scenes目录: {scenes_path}")
                for item in os.listdir(scenes_path):
                    scene_dir = os.path.join(scenes_path, item)
                    if os.path.isdir(scene_dir):
                        self._register_scene(scene_dir, dataset_name, dataset_type,
                                             dataset_weight, dataset_scenes)
            elif dataset_type in ['sekai', 'spatialvid', 'openx']:
                # BUGFIX(consistency): 'sekai' previously had its own branch
                # identical to this one; the duplicate is now consolidated.
                # These datasets keep scenes directly under the root, each
                # with an encoded_video.pth.
                for item in os.listdir(dataset_path):
                    scene_dir = os.path.join(dataset_path, item)
                    if os.path.isdir(scene_dir):
                        encoded_path = os.path.join(scene_dir, "encoded_video.pth")
                        if os.path.exists(encoded_path):
                            self._register_scene(scene_dir, dataset_name, dataset_type,
                                                 dataset_weight, dataset_scenes)

        print(f"  ✅ 找到 {len(dataset_scenes)} 个场景")
        total_scenes += len(dataset_scenes)

    # Per-dataset scene counts, for the startup log.
    dataset_counts = {}
    for scene_dir in self.scene_dirs:
        info = self.dataset_info[scene_dir]
        key = f"{info['name']} ({info['type']})"
        dataset_counts[key] = dataset_counts.get(key, 0) + 1

    for dataset_key, count in dataset_counts.items():
        print(f"  - {dataset_key}: {count} 个场景")

    assert len(self.scene_dirs) > 0, "No encoded scenes found!"

    # Normalized per-scene sampling probabilities.
    total_weight = sum(self.dataset_weights)
    self.sampling_probs = [w / total_weight for w in self.dataset_weights]

def _register_scene(self, scene_dir, dataset_name, dataset_type, dataset_weight, dataset_scenes):
    """Record one scene directory in the unified index."""
    self.scene_dirs.append(scene_dir)
    dataset_scenes.append(scene_dir)
    self.dataset_info[scene_dir] = {
        'name': dataset_name,
        'type': dataset_type,
        'weight': dataset_weight
    }
    self.dataset_weights.append(dataset_weight)
def select_dynamic_segment_nuscenes(self, scene_info):
    """NuScenes-specific FramePack segment selection.

    Works in latent (compressed) frame coordinates and additionally maps
    the chosen segment onto the NuScenes keyframe list for pose lookup.

    Returns:
        segment dict (same schema as the generic selector plus
        'reference_keyframe_idx' / 'target_keyframe_indices'), or None when
        the scene is too short or no condition keyframe exists.
    """
    keyframe_indices = scene_info['keyframe_indices']  # original-frame indices
    total_frames = scene_info['total_frames']          # original frame count

    if len(keyframe_indices) < 2:
        return None

    compressed_total_frames = total_frames // self.time_compression_ratio
    compressed_keyframe_indices = [idx // self.time_compression_ratio for idx in keyframe_indices]

    min_condition_compressed = self.min_condition_frames // self.time_compression_ratio
    max_condition_compressed = self.max_condition_frames // self.time_compression_ratio
    target_frames_compressed = self.target_frames // self.time_compression_ratio

    # Condition-length sampling: ~15% single frame, ~75% uniform, ~10% = target.
    ratio = random.random()
    if ratio < 0.15:
        condition_frames_compressed = 1
    elif 0.15 <= ratio < 0.9:
        condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed)
    else:
        condition_frames_compressed = target_frames_compressed

    min_required_frames = condition_frames_compressed + target_frames_compressed
    if compressed_total_frames < min_required_frames:
        return None

    start_frame_compressed = random.randint(0, compressed_total_frames - min_required_frames - 1)
    condition_end_compressed = start_frame_compressed + condition_frames_compressed
    target_end_compressed = condition_end_compressed + target_frames_compressed

    latent_indices = torch.arange(condition_end_compressed, target_end_compressed)

    # 1x frames: segment start + last condition frame.
    clean_latent_indices = torch.cat([
        torch.tensor([start_frame_compressed]),
        torch.tensor([condition_end_compressed - 1]),
    ])

    # 2x frames: last 2 condition frames.
    # BUGFIX: the original used arange(clean_latent_2x_start - 1, ...), which
    # could emit indices below start_frame (even negative when start_frame is
    # 0).  Clamp the window to the segment start, matching the sekai selector.
    if condition_frames_compressed >= 2:
        clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2 - 1)
        clean_latent_2x_indices = torch.arange(clean_latent_2x_start, condition_end_compressed - 1)
    else:
        clean_latent_2x_indices = torch.tensor([], dtype=torch.long)

    # 4x frames: up to 16 history frames.
    # BUGFIX: same out-of-range issue as above (arange(clean_4x_start - 3, ...))
    # and the guard was `>= 1`, letting 1-3-frame conditions reach before the
    # segment start; use the sekai selector's clamped form and `> 3` guard.
    if condition_frames_compressed > 3:
        clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16 - 3)
        clean_latent_4x_indices = torch.arange(clean_4x_start, condition_end_compressed - 3)
    else:
        clean_latent_4x_indices = torch.tensor([], dtype=torch.long)

    # NuScenes-specific: locate keyframes inside the chosen ranges.
    condition_keyframes_compressed = [idx for idx in compressed_keyframe_indices
                                      if start_frame_compressed <= idx < condition_end_compressed]
    target_keyframes_compressed = [idx for idx in compressed_keyframe_indices
                                   if condition_end_compressed <= idx < target_end_compressed]

    if not condition_keyframes_compressed:
        return None

    # Reference pose = last keyframe of the condition segment.
    reference_keyframe_compressed = max(condition_keyframes_compressed)

    reference_keyframe_original_idx = None
    for i, compressed_idx in enumerate(compressed_keyframe_indices):
        if compressed_idx == reference_keyframe_compressed:
            reference_keyframe_original_idx = i
            break
    if reference_keyframe_original_idx is None:
        return None

    # Keyframe-list positions of the target-segment keyframes.
    target_keyframes_original_indices = []
    for compressed_idx in target_keyframes_compressed:
        for i, comp_idx in enumerate(compressed_keyframe_indices):
            if comp_idx == compressed_idx:
                target_keyframes_original_indices.append(i)
                break

    # Original-frame index for every latent frame of the segment.
    keyframe_original_idx = [compressed_idx * 4
                             for compressed_idx in range(start_frame_compressed, target_end_compressed)]

    return {
        'start_frame': start_frame_compressed,
        'condition_frames': condition_frames_compressed,
        'target_frames': target_frames_compressed,
        'condition_range': (start_frame_compressed, condition_end_compressed),
        'target_range': (condition_end_compressed, target_end_compressed),

        # FramePack-style indices
        'latent_indices': latent_indices,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,

        'keyframe_original_idx': keyframe_original_idx,
        'original_condition_frames': condition_frames_compressed * self.time_compression_ratio,
        'original_target_frames': target_frames_compressed * self.time_compression_ratio,

        # NuScenes-specific extras
        'reference_keyframe_idx': reference_keyframe_original_idx,
        'target_keyframe_indices': target_keyframes_original_indices,
    }
def calculate_relative_rotation(self, current_rotation, reference_rotation):
    """Relative quaternion conj(q_ref) * q_cur in (w, x, y, z) order — NuScenes only.

    Assumes unit quaternions, so the conjugate of the reference acts as its
    inverse.
    """
    q_cur = torch.tensor(current_rotation, dtype=torch.float32)
    q_ref = torch.tensor(reference_rotation, dtype=torch.float32)

    # Conjugate of the reference quaternion.
    w1 = q_ref[0]
    x1, y1, z1 = -q_ref[1], -q_ref[2], -q_ref[3]
    w2, x2, y2, z2 = q_cur

    # Hamilton product conj(q_ref) * q_cur.
    return torch.tensor([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])
总是16帧,直接用真实索引 + 0填充 + clean_latent_4x_indices = segment_info['clean_latent_4x_indices'] + + # 创建固定长度16的latents,初始化为0 + clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype) + clean_latent_4x_indices_final = torch.full((16,), -1, dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的4x索引 + if len(clean_latent_4x_indices) > 0: + actual_4x_frames = len(clean_latent_4x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 16 - actual_4x_frames) + end_pos = 16 + actual_start = max(0, actual_4x_frames - 16) # 如果超过16帧,只取最后16帧 + + clean_latents_4x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :] + clean_latent_4x_indices_final[start_pos:end_pos] = clean_latent_4x_indices[actual_start:] + + # 🔧 2x条件帧 - 总是2帧,直接用真实索引 + 0填充 + clean_latent_2x_indices = segment_info['clean_latent_2x_indices'] + + # 创建固定长度2的latents,初始化为0 + clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype) + clean_latent_2x_indices_final = torch.full((2,), -1, dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的2x索引 + if len(clean_latent_2x_indices) > 0: + actual_2x_frames = len(clean_latent_2x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 2 - actual_2x_frames) + end_pos = 2 + actual_start = max(0, actual_2x_frames - 2) # 如果超过2帧,只取最后2帧 + + clean_latents_2x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_2x_indices[actual_start:], :, :] + clean_latent_2x_indices_final[start_pos:end_pos] = clean_latent_2x_indices[actual_start:] + + # 🔧 移除添加的batch维度,返回原始格式 + if B == 1: + main_latents = main_latents.squeeze(0) # [1, C, T, H, W] -> [C, T, H, W] + clean_latents = clean_latents.squeeze(0) + clean_latents_2x = clean_latents_2x.squeeze(0) + clean_latents_4x = clean_latents_4x.squeeze(0) + + return { + 'latents': main_latents, + 'clean_latents': clean_latents, + 'clean_latents_2x': clean_latents_2x, + 'clean_latents_4x': clean_latents_4x, + 'latent_indices': segment_info['latent_indices'], + 'clean_latent_indices': 
segment_info['clean_latent_indices'], + 'clean_latent_2x_indices': clean_latent_2x_indices_final, # 🔧 使用真实索引(含-1填充) + 'clean_latent_4x_indices': clean_latent_4x_indices_final, # 🔧 使用真实索引(含-1填充) + } + + def create_sekai_pose_embeddings(self, cam_data, segment_info): + """创建Sekai风格的pose embeddings""" + cam_data_seq = cam_data['extrinsic'] + + # 为所有帧计算相对pose + all_keyframe_indices = [] + for compressed_idx in range(segment_info['start_frame'], segment_info['target_range'][1]): + all_keyframe_indices.append(compressed_idx * 4) + + relative_cams = [] + for idx in all_keyframe_indices: + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx + 4] + relative_cam = compute_relative_pose(cam_prev, cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def create_openx_pose_embeddings(self, cam_data, segment_info): + """🔧 创建OpenX风格的pose embeddings - 类似sekai但处理更短的序列""" + cam_data_seq = cam_data['extrinsic'] + + # 为所有帧计算相对pose - OpenX使用4倍间隔 + all_keyframe_indices = [] + for compressed_idx in range(segment_info['start_frame'], segment_info['target_range'][1]): + keyframe_idx = compressed_idx * 4 + if keyframe_idx + 4 < len(cam_data_seq): + all_keyframe_indices.append(keyframe_idx) + + relative_cams = [] + for idx in all_keyframe_indices: + if idx + 4 < len(cam_data_seq): + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx + 4] + relative_cam = compute_relative_pose(cam_prev, cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + else: + # 如果没有下一帧,使用单位矩阵 + identity_cam = torch.eye(3, 4) + relative_cams.append(identity_cam) + + if len(relative_cams) == 0: + return None + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = 
pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def create_spatialvid_pose_embeddings(self, cam_data, segment_info): + """🔧 创建Spatialvid风格的pose embeddings - camera间隔为1帧而非4帧""" + cam_data_seq = cam_data['extrinsic'] + + # 为所有帧计算相对pose - spatialvid特有:每隔1帧而不是4帧 + all_keyframe_indices = [] + for compressed_idx in range(segment_info['start_frame'], segment_info['target_range'][1]): + # 🔧 spatialvid关键差异:camera每隔4帧有一个,但索引递增1 + all_keyframe_indices.append(compressed_idx) + + relative_cams = [] + for idx in all_keyframe_indices: + # 🔧 spatialvid关键差异:current和next是+1而不是+4 + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx + 1] # 这里是+1,不是+4 + relative_cam = compute_relative_pose_matrix(cam_prev, cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def create_nuscenes_pose_embeddings_framepack(self, scene_info, segment_info): + """创建NuScenes风格的pose embeddings - FramePack版本(简化版本,直接7维)""" + keyframe_poses = scene_info['keyframe_poses'] + reference_keyframe_idx = segment_info['reference_keyframe_idx'] + target_keyframe_indices = segment_info['target_keyframe_indices'] + + if reference_keyframe_idx >= len(keyframe_poses): + return None + + reference_pose = keyframe_poses[reference_keyframe_idx] + + # 为所有帧(condition + target)创建pose embeddings + start_frame = segment_info['start_frame'] + condition_end_compressed = start_frame + segment_info['condition_frames'] + target_end_compressed = condition_end_compressed + segment_info['target_frames'] + + # 压缩后的关键帧索引 + compressed_keyframe_indices = [idx // self.time_compression_ratio for idx in scene_info['keyframe_indices']] + + # 找到condition段的关键帧 + condition_keyframes_compressed = [idx for idx in compressed_keyframe_indices + if start_frame <= idx < condition_end_compressed] + + # 找到对应的原始关键帧索引 + 
condition_keyframes_original_indices = [] + for compressed_idx in condition_keyframes_compressed: + for i, comp_idx in enumerate(compressed_keyframe_indices): + if comp_idx == compressed_idx: + condition_keyframes_original_indices.append(i) + break + + pose_vecs = [] + + # 为condition帧计算pose + for i in range(segment_info['condition_frames']): + if not condition_keyframes_original_indices: + translation = torch.zeros(3, dtype=torch.float32) + rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) + else: + # 为condition帧分配pose + if len(condition_keyframes_original_indices) == 1: + keyframe_idx = condition_keyframes_original_indices[0] + else: + if segment_info['condition_frames'] == 1: + keyframe_idx = condition_keyframes_original_indices[0] + else: + interp_ratio = i / (segment_info['condition_frames'] - 1) + interp_idx = int(interp_ratio * (len(condition_keyframes_original_indices) - 1)) + keyframe_idx = condition_keyframes_original_indices[interp_idx] + + if keyframe_idx >= len(keyframe_poses): + translation = torch.zeros(3, dtype=torch.float32) + rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) + else: + condition_pose = keyframe_poses[keyframe_idx] + + translation = torch.tensor( + np.array(condition_pose['translation']) - np.array(reference_pose['translation']), + dtype=torch.float32 + ) + + relative_rotation = self.calculate_relative_rotation( + condition_pose['rotation'], + reference_pose['rotation'] + ) + + rotation = relative_rotation + + # 🔧 简化:直接7维 [translation(3) + rotation(4)] + pose_vec = torch.cat([translation, rotation], dim=0) # [7D] + pose_vecs.append(pose_vec) + + # 为target帧计算pose + if not target_keyframe_indices: + for i in range(segment_info['target_frames']): + pose_vec = torch.cat([ + torch.zeros(3, dtype=torch.float32), + torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32), + ], dim=0) # [7D] + pose_vecs.append(pose_vec) + else: + for i in range(segment_info['target_frames']): + if len(target_keyframe_indices) 
== 1: + target_keyframe_idx = target_keyframe_indices[0] + else: + if segment_info['target_frames'] == 1: + target_keyframe_idx = target_keyframe_indices[0] + else: + interp_ratio = i / (segment_info['target_frames'] - 1) + interp_idx = int(interp_ratio * (len(target_keyframe_indices) - 1)) + target_keyframe_idx = target_keyframe_indices[interp_idx] + + if target_keyframe_idx >= len(keyframe_poses): + pose_vec = torch.cat([ + torch.zeros(3, dtype=torch.float32), + torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32), + ], dim=0) # [7D] + else: + target_pose = keyframe_poses[target_keyframe_idx] + + relative_translation = torch.tensor( + np.array(target_pose['translation']) - np.array(reference_pose['translation']), + dtype=torch.float32 + ) + + relative_rotation = self.calculate_relative_rotation( + target_pose['rotation'], + reference_pose['rotation'] + ) + + # 🔧 简化:直接7维 [translation(3) + rotation(4)] + pose_vec = torch.cat([relative_translation, relative_rotation], dim=0) # [7D] + + pose_vecs.append(pose_vec) + + if not pose_vecs: + return None + + pose_sequence = torch.stack(pose_vecs, dim=0) # [total_frames, 7] + + return pose_sequence + + # 修改select_dynamic_segment方法 + def select_dynamic_segment(self, full_latents, dataset_type, scene_info=None): + """🔧 根据数据集类型选择不同的段落选择策略""" + if dataset_type == 'nuscenes' and scene_info is not None: + return self.select_dynamic_segment_nuscenes(scene_info) + else: + # 原有的sekai方式 + total_lens = full_latents.shape[1] + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + max_condition_compressed = min(total_lens-target_frames_compressed-1, max_condition_compressed) + + ratio = random.random() + if ratio < 0.15: + condition_frames_compressed = 1 + elif 0.15 <= ratio < 0.9 or total_lens <= 2*target_frames_compressed + 1: + 
condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if total_lens < min_required_frames: + return None + + start_frame_compressed = random.randint(0, total_lens - min_required_frames - 1) + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + # FramePack风格的索引处理 + latent_indices = torch.arange(condition_end_compressed, target_end_compressed) + + # 1x帧:起始帧 + 最后1帧 + clean_latent_indices_start = torch.tensor([start_frame_compressed]) + clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1]) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices]) + + # 🔧 2x帧:根据实际condition长度确定 + if condition_frames_compressed >= 2: + # 取最后2帧(如果有的话) + clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2-1) + clean_latent_2x_indices = torch.arange(clean_latent_2x_start, condition_end_compressed-1) + else: + # 如果condition帧数不足2帧,创建空索引 + clean_latent_2x_indices = torch.tensor([], dtype=torch.long) + + # 🔧 4x帧:根据实际condition长度确定,最多16帧 + if condition_frames_compressed > 3: + # 取最多16帧的历史(如果有的话) + clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16-3) + clean_latent_4x_indices = torch.arange(clean_4x_start, condition_end_compressed-3) + else: + clean_latent_4x_indices = torch.tensor([], dtype=torch.long) + + # 对应的原始关键帧索引 + keyframe_original_idx = [] + for compressed_idx in range(start_frame_compressed, target_end_compressed): + if dataset_type == 'spatialvid': + keyframe_original_idx.append(compressed_idx) # spatialvid直接使用compressed_idx + elif dataset_type == 'openx' or 'sekai': # 🔧 新增openx处理 + keyframe_original_idx.append(compressed_idx * 4) # openx使用4倍间隔 + + return { + 'start_frame': 
start_frame_compressed, + 'condition_frames': condition_frames_compressed, + 'target_frames': target_frames_compressed, + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + + # FramePack风格的索引 + 'latent_indices': latent_indices, + 'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + + 'keyframe_original_idx': keyframe_original_idx, + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + + # 修改create_pose_embeddings方法 + def create_pose_embeddings(self, cam_data, segment_info, dataset_type, scene_info=None): + """🔧 根据数据集类型创建pose embeddings""" + if dataset_type == 'nuscenes' and scene_info is not None: + return self.create_nuscenes_pose_embeddings_framepack(scene_info, segment_info) + elif dataset_type == 'spatialvid': # 🔧 新增spatialvid处理 + return self.create_spatialvid_pose_embeddings(cam_data, segment_info) + elif dataset_type == 'sekai': + return self.create_sekai_pose_embeddings(cam_data, segment_info) + elif dataset_type == 'openx': # 🔧 新增openx处理 + return self.create_openx_pose_embeddings(cam_data, segment_info) + + def __getitem__(self, index): + while True: + try: + # 根据权重随机选择场景 + scene_idx = np.random.choice(len(self.scene_dirs), p=self.sampling_probs) + scene_dir = self.scene_dirs[scene_idx] + dataset_info = self.dataset_info[scene_dir] + + dataset_name = dataset_info['name'] + dataset_type = dataset_info['type'] + + # 🔧 根据数据集类型加载数据 + scene_info = None + if dataset_type == 'nuscenes': + # NuScenes需要加载scene_info.json + scene_info_path = os.path.join(scene_dir, "scene_info.json") + if os.path.exists(scene_info_path): + with open(scene_info_path, 'r') as f: + scene_info = json.load(f) + + # NuScenes使用不同的编码文件名 + encoded_path = os.path.join(scene_dir, 
"encoded_video-480p.pth") + if not os.path.exists(encoded_path): + encoded_path = os.path.join(scene_dir, "encoded_video.pth") # fallback + + encoded_data = torch.load(encoded_path, weights_only=True, map_location="cpu") + else: + # Sekai数据集 + encoded_path = os.path.join(scene_dir, "encoded_video.pth") + encoded_data = torch.load(encoded_path, weights_only=False, map_location="cpu") + + full_latents = encoded_data['latents'] + if full_latents.shape[1] <= 10: + continue + cam_data = encoded_data.get('cam_emb', encoded_data) + + # 🔧 验证NuScenes的latent帧数 + if dataset_type == 'nuscenes' and scene_info is not None: + expected_latent_frames = scene_info['total_frames'] // self.time_compression_ratio + actual_latent_frames = full_latents.shape[1] + + if abs(actual_latent_frames - expected_latent_frames) > 2: + print(f"⚠️ NuScenes Latent帧数不匹配,跳过此样本") + continue + + # 使用数据集特定的段落选择策略 + segment_info = self.select_dynamic_segment(full_latents, dataset_type, scene_info) + if segment_info is None: + continue + + # 创建数据集特定的pose embeddings + all_camera_embeddings = self.create_pose_embeddings(cam_data, segment_info, dataset_type, scene_info) + if all_camera_embeddings is None: + continue + + # 准备FramePack风格的多尺度输入 + framepack_inputs = self.prepare_framepack_inputs(full_latents, segment_info) + + n = segment_info["condition_frames"] + m = segment_info['target_frames'] + + # 处理camera embedding with mask + mask = torch.zeros(n+m, dtype=torch.float32) + mask[:n] = 1.0 + mask = mask.view(-1, 1) + + # 🔧 NuScenes返回的是直接的embedding,Sekai返回的是tensor + if isinstance(all_camera_embeddings, torch.Tensor): + camera_with_mask = torch.cat([all_camera_embeddings, mask], dim=1) + else: + # NuScenes风格,直接就是最终的embedding + camera_with_mask = torch.cat([all_camera_embeddings, mask], dim=1) + + result = { + # FramePack风格的多尺度输入 + "latents": framepack_inputs['latents'], + "clean_latents": framepack_inputs['clean_latents'], + "clean_latents_2x": framepack_inputs['clean_latents_2x'], + "clean_latents_4x": 
framepack_inputs['clean_latents_4x'], + "latent_indices": framepack_inputs['latent_indices'], + "clean_latent_indices": framepack_inputs['clean_latent_indices'], + "clean_latent_2x_indices": framepack_inputs['clean_latent_2x_indices'], + "clean_latent_4x_indices": framepack_inputs['clean_latent_4x_indices'], + + # Camera数据 + "camera": camera_with_mask, + + # 其他数据 + "prompt_emb": encoded_data["prompt_emb"], + "image_emb": encoded_data.get("image_emb", {}), + + # 元信息 + "condition_frames": n, + "target_frames": m, + "scene_name": os.path.basename(scene_dir), + "dataset_name": dataset_name, + "dataset_type": dataset_type, + "original_condition_frames": segment_info['original_condition_frames'], + "original_target_frames": segment_info['original_target_frames'], + } + + return result + + except Exception as e: + print(f"Error loading sample: {e}") + traceback.print_exc() + continue + + def __len__(self): + return self.steps_per_epoch + +def replace_dit_model_in_manager(): + """在模型加载前替换DiT模型类为MoE版本""" + from diffsynth.models.wan_video_dit_moe import WanModelMoe + from diffsynth.configs.model_config import model_loader_configs + + # 修改model_loader_configs中的配置 + for i, config in enumerate(model_loader_configs): + keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config + + # 检查是否包含wan_video_dit模型 + if 'wan_video_dit' in model_names: + new_model_names = [] + new_model_classes = [] + + for name, cls in zip(model_names, model_classes): + if name == 'wan_video_dit': + new_model_names.append(name) + new_model_classes.append(WanModelMoe) # 🔧 使用MoE版本 + print(f"✅ 替换了模型类: {name} -> WanModelMoe") + else: + new_model_names.append(name) + new_model_classes.append(cls) + + # 更新配置 + model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) + +class MultiDatasetLightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, + 
use_gradient_checkpointing_offload=False, + resume_ckpt_path=None, + # 🔧 MoE参数 + use_moe=False, + moe_config=None + ): + super().__init__() + self.use_moe = use_moe + self.moe_config = moe_config or {} + + replace_dit_model_in_manager() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + # 添加FramePack的clean_x_embedder + self.add_framepack_components() + if self.use_moe: + self.add_moe_components() + + # 🔧 添加camera编码器(wan_video_dit_moe.py已经包含MoE逻辑) + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + # 🔧 简化:只添加传统camera编码器,MoE逻辑在wan_video_dit_moe.py中 + block.cam_encoder = nn.Linear(13, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + self.pipe.dit.load_state_dict(state_dict, strict=False) + print('load checkpoint:', resume_ckpt_path) + + self.freeze_parameters() + + # 🔧 训练参数设置 + for name, module in self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in [ + "moe", "sekai_processor"]): + for param in module.parameters(): + param.requires_grad = True + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + # 创建可视化目录 + self.vis_dir = "multi_dataset_dynamic/visualizations" + os.makedirs(self.vis_dir, 
exist_ok=True) + + def add_moe_components(self): + """🔧 添加MoE相关组件 - 类似add_framepack_components的方式""" + if not hasattr(self.pipe.dit, 'moe_config'): + self.pipe.dit.moe_config = self.moe_config + print("✅ 添加了MoE配置到模型") + + # 为每个block动态添加MoE组件 + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + unified_dim = self.moe_config.get("unified_dim", 30) + + for i, block in enumerate(self.pipe.dit.blocks): + from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE + + # Sekai模态处理器 - 输出unified_dim + block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim) + + # NuScenes模态处理器 - 输出unified_dim + # block.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim) + + # block.openx_processor = ModalityProcessor("openx", 13, unified_dim) # OpenX使用13维输入,类似sekai但独立处理 + + + # MoE网络 - 输入unified_dim,输出dim + block.moe = MultiModalMoE( + unified_dim=unified_dim, + output_dim=dim, # 输出维度匹配transformer block的dim + num_experts=self.moe_config.get("num_experts", 4), + top_k=self.moe_config.get("top_k", 2) + ) + + print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {self.moe_config.get('num_experts', 4)})") + + + def add_framepack_components(self): + """🔧 添加FramePack相关组件""" + if not hasattr(self.pipe.dit, 'clean_x_embedder'): + inner_dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + if scale == "1x": + return self.proj(x) + elif scale == "2x": + return self.proj_2x(x) + elif scale == "4x": + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + self.pipe.dit.clean_x_embedder = CleanXEmbedder(inner_dim) + print("✅ 
添加了FramePack的clean_x_embedder组件") + + def freeze_parameters(self): + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + def training_step(self, batch, batch_idx): + """🔧 多数据集训练步骤""" + condition_frames = batch["condition_frames"][0].item() + target_frames = batch["target_frames"][0].item() + + original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0] + original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0] + + dataset_name = batch.get("dataset_name", ["unknown"])[0] + dataset_type = batch.get("dataset_type", ["sekai"])[0] + scene_name = batch.get("scene_name", ["unknown"])[0] + + # 准备输入数据 + latents = batch["latents"].to(self.device) + if len(latents.shape) == 4: + latents = latents.unsqueeze(0) + + clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None + if clean_latents is not None and len(clean_latents.shape) == 4: + clean_latents = clean_latents.unsqueeze(0) + + clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None + if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4: + clean_latents_2x = clean_latents_2x.unsqueeze(0) + + clean_latents_4x = batch["clean_latents_4x"].to(self.device) if batch["clean_latents_4x"].numel() > 0 else None + if clean_latents_4x is not None and len(clean_latents_4x.shape) == 4: + clean_latents_4x = clean_latents_4x.unsqueeze(0) + + # 索引处理 + latent_indices = batch["latent_indices"].to(self.device) + clean_latent_indices = batch["clean_latent_indices"].to(self.device) if batch["clean_latent_indices"].numel() > 0 else None + clean_latent_2x_indices = batch["clean_latent_2x_indices"].to(self.device) if batch["clean_latent_2x_indices"].numel() > 0 else None + clean_latent_4x_indices = batch["clean_latent_4x_indices"].to(self.device) if batch["clean_latent_4x_indices"].numel() > 0 else None + + # Camera embedding处理 + 
cam_emb = batch["camera"].to(self.device) + + # 🔧 根据数据集类型设置modality_inputs + if dataset_type == "sekai": + modality_inputs = {"sekai": cam_emb} + elif dataset_type == "spatialvid": # 🔧 spatialvid使用sekai processor + modality_inputs = {"sekai": cam_emb} # 注意:这里使用"sekai"键 + elif dataset_type == "nuscenes": + modality_inputs = {"nuscenes": cam_emb} + elif dataset_type == "openx": # 🔧 新增:openx使用独立的processor + modality_inputs = {"openx": cam_emb} + else: + modality_inputs = {"sekai": cam_emb} # 默认 + + camera_dropout_prob = 0.05 + if random.random() < camera_dropout_prob: + cam_emb = torch.zeros_like(cam_emb) + # 同时清空modality_inputs + for key in modality_inputs: + modality_inputs[key] = torch.zeros_like(modality_inputs[key]) + print(f"应用camera dropout for CFG training (dataset: {dataset_name}, type: {dataset_type})") + + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + # Loss计算 + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + # FramePack风格的噪声处理 + noisy_condition_latents = None + if clean_latents is not None: + noisy_condition_latents = copy.deepcopy(clean_latents) + is_add_noise = random.random() + if is_add_noise > 0.2: + noise_cond = torch.randn_like(clean_latents) + timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps//4*3, (1,)) + timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + noisy_condition_latents = self.pipe.scheduler.add_noise(clean_latents, noise_cond, timestep_cond) + + extra_input 
= self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # 🔧 Forward调用 - 传递modality_inputs + noise_pred, moe_loss = self.pipe.denoising_model()( + noisy_latents, + timestep=timestep, + cam_emb=cam_emb, + modality_inputs=modality_inputs, # 🔧 传递多模态输入 + latent_indices=latent_indices, + clean_latents=noisy_condition_latents if noisy_condition_latents is not None else clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb, + **extra_input, + **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + + # 计算loss + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + + print(f'--------loss ({dataset_name}-{dataset_type})------------:', loss) + + return loss + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + def on_save_checkpoint(self, checkpoint): + checkpoint_dir = "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_test" + os.makedirs(checkpoint_dir, exist_ok=True) + + current_step = self.global_step + checkpoint.clear() + + state_dict = self.pipe.denoising_model().state_dict() + torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}_moe.ckpt")) + print(f"Saved MoE model checkpoint: step{current_step}_moe.ckpt") + +def train_multi_dataset(args): + """训练支持多数据集MoE的模型""" + + # 🔧 数据集配置 + dataset_configs = [ + { + 
'name': 'sekai-drone', + 'paths': ['/share_zhuyixuan05/zhuyixuan05/sekai-game-drone'], + 'type': 'sekai', + 'weight': 1.0 + }, + { + 'name': 'sekai-walking', + 'paths': ['/share_zhuyixuan05/zhuyixuan05/sekai-game-walking'], + 'type': 'sekai', + 'weight': 1.0 + }, + # { + # 'name': 'spatialvid', + # 'paths': ['/share_zhuyixuan05/zhuyixuan05/spatialvid'], + # 'type': 'spatialvid', + # 'weight': 1.0 + # }, + # { + # 'name': 'nuscenes', + # 'paths': ['/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic'], + # 'type': 'nuscenes', + # 'weight': 4.0 + # }, + # { + # 'name': 'openx-fractal', + # 'paths': ['/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded'], + # 'type': 'openx', + # 'weight': 1.0 + # } + ] + + dataset = MultiDatasetDynamicDataset( + dataset_configs, + steps_per_epoch=args.steps_per_epoch, + min_condition_frames=args.min_condition_frames, + max_condition_frames=args.max_condition_frames, + target_frames=args.target_frames, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + + # 🔧 MoE配置 + moe_config = { + "unified_dim": args.unified_dim, # 新增 + "num_experts": args.moe_num_experts, + "top_k": args.moe_top_k, + "moe_loss_weight": args.moe_loss_weight, + "sekai_input_dim": 13, + "nuscenes_input_dim": 8, + "openx_input_dim": 13 + } + + model = MultiDatasetLightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + use_moe=True, # 总是使用MoE + moe_config=moe_config + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[], + logger=False + ) + 
trainer.fit(model, dataloader) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Train Multi-Dataset FramePack with MoE") + parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors") + parser.add_argument("--output_path", type=str, default="./") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--steps_per_epoch", type=int, default=2000) + parser.add_argument("--max_epochs", type=int, default=100000) + parser.add_argument("--min_condition_frames", type=int, default=8, help="最小条件帧数") + parser.add_argument("--max_condition_frames", type=int, default=120, help="最大条件帧数") + parser.add_argument("--target_frames", type=int, default=32, help="目标帧数") + parser.add_argument("--dataloader_num_workers", type=int, default=4) + parser.add_argument("--accumulate_grad_batches", type=int, default=1) + parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1") + parser.add_argument("--use_gradient_checkpointing", default=False) + parser.add_argument("--use_gradient_checkpointing_offload", action="store_true") + parser.add_argument("--resume_ckpt_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_test/step1500_moe.ckpt") + + # 🔧 MoE参数 + parser.add_argument("--unified_dim", type=int, default=25, help="统一的中间维度") + parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量") + parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家") + parser.add_argument("--moe_loss_weight", type=float, default=0.00, help="MoE损失权重") + + args = parser.parse_args() + + print("🔧 多数据集MoE训练配置:") + print(f" - 使用wan_video_dit_moe.py作为模型") + print(f" - 统一维度: {args.unified_dim}") + print(f" - 专家数量: {args.moe_num_experts}") + print(f" - Top-K: {args.moe_top_k}") + print(f" - MoE损失权重: {args.moe_loss_weight}") + print(" - 数据集:") + print(" - sekai-game-drone (sekai模态)") + print(" - sekai-game-walking 
(sekai模态)") + print(" - spatialvid (使用sekai模态处理器)") + print(" - openx-fractal (使用sekai模态处理器)") + print(f" - nuscenes (nuscenes模态)") + + train_multi_dataset(args) \ No newline at end of file diff --git a/scripts/train_nus.py b/scripts/train_nus.py new file mode 100644 index 0000000000000000000000000000000000000000..7d78ad0df35a8c21e7ef066ae040c1835bdd9312 --- /dev/null +++ b/scripts/train_nus.py @@ -0,0 +1,811 @@ +import copy +import os +import re +import torch, os, imageio, argparse +from torchvision.transforms import v2 +from einops import rearrange +import lightning as pl +import pandas as pd +from diffsynth import WanVideoReCamMasterPipeline, ModelManager, load_state_dict +import torchvision +from PIL import Image +import numpy as np +import random +import json +import torch.nn as nn +import torch.nn.functional as F +import shutil +import wandb +import pdb +import matplotlib.pyplot as plt +import torchvision.utils as vutils +from pose_classifier import PoseClassifier + +class NuScenesVideoDataset(torch.utils.data.Dataset): + def __init__(self, base_path, height=480, width=832, condition_frames=20, target_frames=10, default_text_prompt="A car driving scene captured by front camera", is_i2v=False): + self.base_path = base_path + self.samples_path = os.path.join(base_path, "samples") + + # Get all sample directories + self.sample_dirs = [] + for item in os.listdir(self.samples_path): + sample_path = os.path.join(self.samples_path, item) + if os.path.isdir(sample_path): + # Check if required files exist + condition_path = os.path.join(sample_path, "condition.mp4") + target_path = os.path.join(sample_path, "target.mp4") + poses_path = os.path.join(sample_path, "poses.json") + + if all(os.path.exists(p) for p in [condition_path, target_path, poses_path]): + self.sample_dirs.append(sample_path) + + print(f"Found {len(self.sample_dirs)} valid samples in NuScenes dataset.") + + self.height = height + self.width = width + self.condition_frames = condition_frames + 
self.target_frames = target_frames + self.default_text_prompt = default_text_prompt + self.is_i2v = is_i2v + + self.frame_process = v2.Compose([ + v2.CenterCrop(size=(height, width)), + v2.Resize(size=(height, width), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def crop_and_resize(self, image): + width, height = image.size + scale = max(self.width / width, self.height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + return image + + def load_video_frames(self, video_path): + reader = imageio.get_reader(video_path) + frames = [] + + for frame_data in reader: + frame = Image.fromarray(frame_data) + frame = self.crop_and_resize(frame) + frame = self.frame_process(frame) + frames.append(frame) + + reader.close() + + if len(frames) == 0: + return None + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + return frames + + def __getitem__(self, data_id): + sample_dir = self.sample_dirs[data_id] + + # Load condition video (first 20 frames) + condition_path = os.path.join(sample_dir, "condition.mp4") + condition_video = self.load_video_frames(condition_path) + + # Load target video (next 10 frames) + target_path = os.path.join(sample_dir, "target.mp4") + target_video = self.load_video_frames(target_path) + + # Use default text prompt + text_prompt = self.default_text_prompt + + # Concatenate condition and target videos + if condition_video is not None and target_video is not None: + full_video = torch.cat([condition_video, target_video], dim=1) # Concatenate along time dimension + else: + return self.__getitem__((data_id + 1) % len(self.sample_dirs)) + + data = { + "text": text_prompt, + "video": full_video, + "condition_video": condition_video, + "target_video": target_video, + "path": sample_dir + } + + if self.is_i2v: + # Use first frame 
of condition video as reference image + first_frame = condition_video[:, 0, :, :] # C H W + first_frame_pil = v2.ToPILImage()(first_frame * 0.5 + 0.5) # Denormalize + data["first_frame"] = np.array(first_frame_pil) + + return data + + def __len__(self): + return len(self.sample_dirs) + +class NuScenesTensorDataset(torch.utils.data.Dataset): + def __init__(self, base_path, steps_per_epoch, condition_frames=20, target_frames=10): + self.base_path = base_path + self.samples_path = os.path.join(base_path, "samples") + self.condition_frames = condition_frames + self.target_frames = target_frames + self.pose_classifier = PoseClassifier() + + # Find all samples with encoded data + self.encoded_paths = [] + if os.path.exists(self.samples_path): + for item in os.listdir(self.samples_path): + if item.endswith(".recam.pth"): + encoded_path = os.path.join(self.samples_path, item) + self.encoded_paths.append(encoded_path) + + print(f"Found {len(self.encoded_paths)} encoded samples in NuScenes dataset.") + assert len(self.encoded_paths) > 0, "No encoded data found!" 
+ + self.steps_per_epoch = steps_per_epoch + self.skip = 0 + + def calculate_relative_rotation(self, current_rotation, reference_rotation): + """ + 计算相对于参考帧的相对旋转。 + Args: + current_rotation: 当前帧的四元数 (q_current) [4] + reference_rotation: 参考帧的四元数 (q_ref) [4] + Returns: + relative_rotation: 相对旋转的四元数 [4] + """ + # 将四元数转换为 PyTorch 张量 + q_current = torch.tensor(current_rotation, dtype=torch.float32) + q_ref = torch.tensor(reference_rotation, dtype=torch.float32) + + # 计算参考旋转的逆 (q_ref^-1) + q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]]) + + # 四元数乘法计算相对旋转: q_relative = q_ref^-1 * q_current + w1, x1, y1, z1 = q_ref_inv + w2, x2, y2, z2 = q_current + + relative_rotation = torch.tensor([ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + ]) + + return relative_rotation + + def process_poses(self, poses_path): + """Process poses to create camera embeddings""" + with open(poses_path, 'r') as f: + poses_data = json.load(f) + + target_relative_poses = poses_data['target_relative_poses'] + + # Generate pose embeddings for target frames + pose_embeddings = [] + + if len(target_relative_poses) == 0: + # If no target poses, use zero vectors + for i in range(self.target_frames): + pose_vec = torch.zeros(7, dtype=torch.float32) # 3 translation + 4 rotation + pose_embeddings.append(pose_vec) + else: + # Create pose vectors for target frames + for i in range(self.target_frames): + if len(target_relative_poses) == 1: + # Use the single pose for all frames + pose_data = target_relative_poses[0] + else: + # Simple selection - use closest pose or interpolate indices + pose_idx = min(i * len(target_relative_poses) // self.target_frames, + len(target_relative_poses) - 1) + pose_data = target_relative_poses[pose_idx] + + # Extract translation (3D) and rotation (4D quaternion) + translation = torch.tensor(pose_data['relative_translation'], dtype=torch.float32) 
+ rotation = self.calculate_relative_rotation( + current_rotation=pose_data['current_rotation'], + reference_rotation=pose_data['reference_rotation'] + ) + + # Concatenate to form 7D pose vector + pose_vec = torch.cat([translation, rotation], dim=0) # [7] + pose_embeddings.append(pose_vec) + + # Stack pose embeddings + pose_embedding = torch.stack(pose_embeddings, dim=0) # [target_frames, 7] + + pose_analysis = self.pose_classifier.analyze_pose_sequence(pose_embedding) + pose_classes = pose_analysis['classifications'] + + # 创建类别embedding + class_embeddings = self.pose_classifier.create_class_embedding( + pose_classes, embed_dim=512 + ) + + return { + 'raw_poses': pose_embedding, + 'pose_classes': pose_classes, + 'class_embeddings': class_embeddings, + 'pose_analysis': pose_analysis + } + + def __getitem__(self, index): + while True: + try: + data_id = torch.randint(0, len(self.encoded_paths), (1,))[0] + data_id = (data_id + index) % len(self.encoded_paths) + + encoded_path = self.encoded_paths[data_id] + data = torch.load(encoded_path, weights_only=True, map_location="cpu") + + # Get poses path + sample_name = os.path.basename(encoded_path).replace(".recam.pth", "") + poses_path = os.path.join(self.samples_path, sample_name, "poses.json") + + if not os.path.exists(poses_path): + raise FileNotFoundError(f"poses.json not found for sample {sample_name}") + + pose_data = self.process_poses(poses_path) + + # pose_analysis = pose_data['pose_analysis'] + # class_distribution = pose_analysis['class_distribution'] + # if class_distribution["backward"] > 0 or class_distribution["forward"] > 0: + # index = (index + 1) % len(self.encoded_paths) + # self.skip += 1 + # print(f"skip {self.skip}") + # continue + + result = { + "latents": data["latents"], + "prompt_emb": data["prompt_emb"], + "image_emb": data.get("image_emb", {}), + "camera": pose_data['class_embeddings'].to(torch.bfloat16), # 使用类别embedding + "pose_classes": pose_data['pose_classes'], # 保留类别标签用于分析 + "raw_poses": 
pose_data['raw_poses'], # 保留原始pose用于对比 + "pose_analysis": pose_data['pose_analysis'] # 保留分析信息 + } + + break + + except Exception as e: + print(f"ERROR WHEN LOADING: {e}") + index = random.randrange(len(self.encoded_paths)) + + return result + + def __len__(self): + return self.steps_per_epoch + + +class LightningModelForDataProcess(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_path = [text_encoder_path, vae_path] + if image_encoder_path is not None: + model_path.append(image_encoder_path) + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models(model_path) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + def test_step(self, batch, batch_idx): + text, video, path = batch["text"][0], batch["video"], batch["path"][0] + + self.pipe.device = self.device + if video is not None: + pth_path = path + ".recam.pth" + if not os.path.exists(pth_path): + # prompt + prompt_emb = self.pipe.encode_prompt(text) + # video + video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] + # image + if "first_frame" in batch: + first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy()) + _, _, num_frames, height, width = video.shape + image_emb = self.pipe.encode_image(first_frame, num_frames, height, width) + else: + image_emb = {} + data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb} + torch.save(data, pth_path) + print(f"Output: {pth_path}") + else: + print(f"File {pth_path} already exists, skipping.") + +class LightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, 
use_gradient_checkpointing_offload=False, + resume_ckpt_path=None + ): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + dim=self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + block.cam_encoder = nn.Linear(512, dim) # Changed from 12 to 7 + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + self.pipe.dit.load_state_dict(state_dict, strict=True) + + self.freeze_parameters() + for name, module in self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn"]): + print(f"Trainable: {name}") + for param in module.parameters(): + param.requires_grad = True + + trainable_params = 0 + seen_params = set() + for name, module in self.pipe.denoising_model().named_modules(): + for param in module.parameters(): + if param.requires_grad and param not in seen_params: + trainable_params += param.numel() + seen_params.add(param) + print(f"Total number of trainable parameters: {trainable_params}") + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + # 创建可视化目录 + self.vis_dir = "nus/visualizations" + os.makedirs(self.vis_dir, exist_ok=True) + + + def freeze_parameters(self): 
+ # Freeze parameters + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + def visualize_training_results(self, latents, noisy_latents, noise_pred, training_target, step): + """可视化训练结果""" + try: + with torch.no_grad(): + # 分离target和condition部分 + tgt_latent_len = 5 + + # 提取各部分latents + target_latents = latents[:, :, tgt_latent_len:, :, :] # 原始target + condition_latents = latents[:, :, :tgt_latent_len, :, :] # condition + noisy_target_latents = noisy_latents[:, :, tgt_latent_len:, :, :] # 加噪target + + # 解码为视频帧 (取第一个batch) + # 只可视化前几帧以节省内存 + vis_frames = 10 + + # 解码condition frames + condition_sample = condition_latents[:, :, :vis_frames, :, :] + condition_video = self.pipe.decode_video(condition_sample, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + condition_video = condition_video[0].to(torch.float32) # [C, T, H, W] + + # 解码原始target frames + target_sample = target_latents[:, :, :vis_frames, :, :] + target_video = self.pipe.decode_video(target_sample, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + target_video = target_video[0].to(torch.float32) # [C, T, H, W] + + # 解码加噪target frames + noisy_target_sample = noisy_target_latents[:, :, :vis_frames, :, :] + noisy_target_video = self.pipe.decode_video(noisy_target_sample, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + noisy_target_video = noisy_target_video[0].to(torch.float32) # [C, T, H, W] + + # 解码预测结果 (从noise_pred重构) + pred_latents = noisy_target_latents - noise_pred[:, :, 5:, :, :] + pred_video = self.pipe.decode_video(pred_latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) + pred_video = pred_video[0].to(torch.float32) # [C, T, H, W] + + # 归一化到[0,1] + condition_video = (condition_video * 0.5 + 0.5).clamp(0, 1) + target_video = (target_video * 0.5 + 0.5).clamp(0, 1) + noisy_target_video = (noisy_target_video * 0.5 + 0.5).clamp(0, 1) + pred_video = (pred_video * 0.5 + 0.5).clamp(0, 1) + + # 创建可视化图像 + fig, axes = plt.subplots(4, 
vis_frames, figsize=(vis_frames * 3, 12)) + if vis_frames == 1: + axes = axes.reshape(-1, 1) + + for frame_idx in range(vis_frames-1): + # Condition frame + condition_frame = condition_video[:, frame_idx, :, :].permute(1, 2, 0).cpu().numpy() + axes[0, frame_idx].imshow(condition_frame) + axes[0, frame_idx].set_title(f'Condition Frame {frame_idx}') + axes[0, frame_idx].axis('off') + + # Original target frame + target_frame = target_video[:, frame_idx, :, :].permute(1, 2, 0).cpu().numpy() + axes[1, frame_idx].imshow(target_frame) + axes[1, frame_idx].set_title(f'Original Target {frame_idx}') + axes[1, frame_idx].axis('off') + + # Noisy target frame + noisy_frame = noisy_target_video[:, frame_idx, :, :].permute(1, 2, 0).cpu().numpy() + axes[2, frame_idx].imshow(noisy_frame) + axes[2, frame_idx].set_title(f'Noisy Target {frame_idx}') + axes[2, frame_idx].axis('off') + + # Predicted frame + pred_frame = pred_video[:, frame_idx, :, :].permute(1, 2, 0).cpu().numpy() + axes[3, frame_idx].imshow(pred_frame) + axes[3, frame_idx].set_title(f'Prediction {frame_idx}') + axes[3, frame_idx].axis('off') + + plt.tight_layout() + + # 保存图像 + save_path = os.path.join(self.vis_dir, f"training_step_{step:06d}.png") + plt.savefig(save_path, dpi=100, bbox_inches='tight') + plt.close() + + # 记录到wandb + if wandb.run is not None: + wandb.log({ + "training_visualization": wandb.Image(save_path), + "step": step + }) + + print(f"Visualization saved to {save_path}") + + except Exception as e: + print(f"Error during visualization: {e}") + + def training_step(self, batch, batch_idx): + # Data + latents = batch["latents"].to(self.device) + # 裁剪空间尺寸 (例如裁剪到固定的 height 和 width) + target_height, target_width = 50, 70 # 根据你的需求调整 + current_height, current_width = latents.shape[3], latents.shape[4] + + if current_height > target_height or current_width > target_width: + # 中心裁剪 + h_start = (current_height - target_height) // 2 + w_start = (current_width - target_width) // 2 + latents = latents[:, :, :, + 
h_start:h_start+target_height, + w_start:w_start+target_width] + + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + cam_emb = batch["camera"].to(self.device) + + # Loss + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + extra_input = self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + tgt_latent_len = 10 + noisy_latents[:, :, :tgt_latent_len, ...] = origin_latents[:, :, :tgt_latent_len, ...] + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # Compute loss + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + loss = torch.nn.functional.mse_loss(noise_pred[:, :, tgt_latent_len:, ...].float(), training_target[:, :, tgt_latent_len:, ...].float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + + # 可视化 (每10步一次) + if self.global_step % 1000 == 500: + self.visualize_training_results( + latents=origin_latents, + noisy_latents=noisy_latents, + noise_pred=noise_pred, + training_target=training_target, + step=self.global_step + ) + + # Record log + wandb.log({ + "train_loss": loss.item(), + "timestep": timestep.item(), + "global_step": self.global_step + }) + + return loss + + + def configure_optimizers(self): + 
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + + def on_save_checkpoint(self, checkpoint): + checkpoint_dir = "/home/zhuyixuan05/ReCamMaster/nus" + print(f"Checkpoint directory: {checkpoint_dir}") + current_step = self.global_step + print(f"Current step: {current_step}") + + checkpoint.clear() + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters())) + trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) + state_dict = self.pipe.denoising_model().state_dict() + torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}_8040.ckpt")) + +def data_process(args): + if args.dataset_type == "nuscenes": + dataset = NuScenesVideoDataset( + args.dataset_path, + height=args.height, + width=args.width, + condition_frames=args.condition_frames, + target_frames=args.target_frames, + default_text_prompt=args.default_text_prompt, + is_i2v=args.image_encoder_path is not None + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=False, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForDataProcess( + text_encoder_path=args.text_encoder_path, + image_encoder_path=args.image_encoder_path, + vae_path=args.vae_path, + tiled=args.tiled, + tile_size=(args.tile_size_height, args.tile_size_width), + tile_stride=(args.tile_stride_height, args.tile_stride_width), + ) + trainer = pl.Trainer( + accelerator="gpu", + devices="auto", + default_root_dir=args.output_path, + ) + trainer.test(model, dataloader) + + +def train(args): + if args.dataset_type == "nuscenes": + dataset = NuScenesTensorDataset( + args.dataset_path, + steps_per_epoch=args.steps_per_epoch, + condition_frames=args.condition_frames, + target_frames=args.target_frames, + ) + + dataloader = 
torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + ) + + if args.use_swanlab: + project_name = "nuscenes-recam" if args.dataset_type == "nuscenes" else "recam" + wandb.init( + project=project_name, + name=f"{args.dataset_type}-video-generation", + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + choices=["original", "nuscenes"], + help="Type of dataset. 'original' for the original format, 'nuscenes' for NuScenes format.", + ) + parser.add_argument( + "--dataset_path", + type=str, + default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_3", + help="The path of the Dataset.", + ) + parser.add_argument( + "--output_path", + type=str, + default="/share_zhuyixuan05/zhuyixuan05/nus_checkpoint", + help="Path to save the model.", + ) + parser.add_argument( + "--text_encoder_path", + type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + help="Path of text encoder.", + ) + parser.add_argument( + "--image_encoder_path", + type=str, + default=None, + help="Path of image encoder.", + ) + parser.add_argument( + "--vae_path", + type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + help="Path of VAE.", + ) + parser.add_argument( + "--dit_path", + type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + help="Path of DiT.", + ) + parser.add_argument( + "--tiled", + default=True, + action="store_true", + help="Whether enable tile encode in VAE. 
This option can reduce VRAM required.", + ) + parser.add_argument( + "--tile_size_height", + type=int, + default=34, + help="Tile size (height) in VAE.", + ) + parser.add_argument( + "--tile_size_width", + type=int, + default=34, + help="Tile size (width) in VAE.", + ) + parser.add_argument( + "--tile_stride_height", + type=int, + default=18, + help="Tile stride (height) in VAE.", + ) + parser.add_argument( + "--tile_stride_width", + type=int, + default=16, + help="Tile stride (width) in VAE.", + ) + parser.add_argument( + "--steps_per_epoch", + type=int, + default=1000, + help="Number of steps per epoch.", + ) + parser.add_argument( + "--num_frames", + type=int, + default=81, + help="Number of frames.", + ) + parser.add_argument( + "--condition_frames", + type=int, + default=20, + help="Number of condition frames for NuScenes dataset.", + ) + parser.add_argument( + "--target_frames", + type=int, + default=10, + help="Number of target frames for NuScenes dataset.", + ) + parser.add_argument( + "--height", + type=int, + default=900, + help="Image height.", + ) + parser.add_argument( + "--width", + type=int, + default=1600, + help="Image width.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=4, + help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Learning rate.", + ) + parser.add_argument( + "--accumulate_grad_batches", + type=int, + default=1, + help="The number of batches in gradient accumulation.", + ) + parser.add_argument( + "--max_epochs", + type=int, + default=2, + help="Number of epochs.", + ) + parser.add_argument( + "--training_strategy", + type=str, + default="deepspeed_stage_1", + choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"], + help="Training strategy", + ) + parser.add_argument( + "--use_gradient_checkpointing", + default=False, + action="store_true", + help="Whether to use gradient checkpointing.", + ) + parser.add_argument( + "--use_gradient_checkpointing_offload", + default=False, + action="store_true", + help="Whether to use gradient checkpointing offload.", + ) + parser.add_argument( + "--use_swanlab", + default=True, + action="store_true", + help="Whether to use SwanLab logger.", + ) + parser.add_argument( + "--swanlab_mode", + default="cloud", + help="SwanLab mode (cloud or local).", + ) + parser.add_argument( + "--metadata_file_name", + type=str, + default="metadata.csv", + ) + parser.add_argument( + "--resume_ckpt_path", + type=str, + default=None, + ) + parser.add_argument( + "--default_text_prompt", + type=str, + default="A car driving scene captured by front camera", + help="Default text prompt for NuScenes samples without description.", + ) + args = parser.parse_args() + return args + +if __name__ == '__main__': + args = parse_args() + os.makedirs(os.path.join(args.output_path, "checkpoints"), exist_ok=True) + if args.task == "data_process": + data_process(args) + elif args.task == "train": + train(args) \ No newline at end of file diff --git a/scripts/train_nus_dynamic.py b/scripts/train_nus_dynamic.py new file mode 100644 index 
0000000000000000000000000000000000000000..00acf9f00f7d7e1ae9e3217bf93f5766771db927 --- /dev/null +++ b/scripts/train_nus_dynamic.py @@ -0,0 +1,734 @@ +import torch +import torch.nn as nn +import lightning as pl +import wandb +import os +import copy +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import os +import json +import torch +import numpy as np +from PIL import Image +import imageio +import random +from torchvision.transforms import v2 +from einops import rearrange +from pose_classifier import PoseClassifier + +class DynamicNuScenesDataset(torch.utils.data.Dataset): + """支持动态历史长度的NuScenes数据集""" + + def __init__(self, base_path, steps_per_epoch, + min_condition_frames=10, max_condition_frames=40, + target_frames=10, height=900, width=1600): + self.base_path = base_path + self.scenes_path = os.path.join(base_path, "scenes") + self.min_condition_frames = min_condition_frames + self.max_condition_frames = max_condition_frames + self.target_frames = target_frames + self.height = height + self.width = width + self.steps_per_epoch = steps_per_epoch + self.pose_classifier = PoseClassifier() + + # 🔧 新增:VAE时间压缩比例 + self.time_compression_ratio = 4 # VAE将时间维度压缩4倍 + + # 查找所有处理好的场景 + self.scene_dirs = [] + if os.path.exists(self.scenes_path): + for item in os.listdir(self.scenes_path): + scene_dir = os.path.join(self.scenes_path, item) + if os.path.isdir(scene_dir): + scene_info_path = os.path.join(scene_dir, "scene_info.json") + if os.path.exists(scene_info_path): + # 检查是否有编码的tensor文件 + encoded_path = os.path.join(scene_dir, "encoded_video-480p.pth") + if os.path.exists(encoded_path): + self.scene_dirs.append(scene_dir) + + # print(f"Found {len(self.scene_dirs)} scenes with encoded data") + assert len(self.scene_dirs) > 0, "No encoded scenes found!" 
+ + # 预处理设置 + self.frame_process = v2.Compose([ + v2.CenterCrop(size=(height, width)), + v2.Resize(size=(height, width), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def calculate_relative_rotation(self, current_rotation, reference_rotation): + """计算相对旋转四元数""" + q_current = torch.tensor(current_rotation, dtype=torch.float32) + q_ref = torch.tensor(reference_rotation, dtype=torch.float32) + + q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]]) + + w1, x1, y1, z1 = q_ref_inv + w2, x2, y2, z2 = q_current + + relative_rotation = torch.tensor([ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + ]) + + return relative_rotation + + def select_dynamic_segment(self, scene_info): + """动态选择条件帧和目标帧 - 修正版本处理VAE时间压缩""" + keyframe_indices = scene_info['keyframe_indices'] # 原始帧索引 + total_frames = scene_info['total_frames'] # 原始总帧数 + + if len(keyframe_indices) < 2: + print('error1____________') + return None + + # 🔧 计算压缩后的帧数 + compressed_total_frames = total_frames // self.time_compression_ratio + compressed_keyframe_indices = [idx // self.time_compression_ratio for idx in keyframe_indices] + + # print(f"原始总帧数: {total_frames}, 压缩后: {compressed_total_frames}") + # print(f"原始关键帧: {keyframe_indices[:5]}..., 压缩后: {compressed_keyframe_indices[:5]}...") + + # 随机选择条件帧长度(基于压缩后的帧数) + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + + ratio = random.random() + print('ratio:',ratio) + if ratio<0.15: + condition_frames_compressed = 1 + elif 0.15<=ratio<0.3: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = 
target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if compressed_total_frames < min_required_frames: + print(f"压缩后帧数不足: {compressed_total_frames} < {min_required_frames}") + return None + + # 随机选择起始位置(基于压缩后的帧数) + max_start = compressed_total_frames - min_required_frames + start_frame_compressed = random.randint(0, max_start) + + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + # 🔧 关键修复:在压缩空间中查找关键帧 + condition_keyframes_compressed = [idx for idx in compressed_keyframe_indices + if start_frame_compressed <= idx < condition_end_compressed] + + target_keyframes_compressed = [idx for idx in compressed_keyframe_indices + if condition_end_compressed <= idx < target_end_compressed] + + if not condition_keyframes_compressed: + print(f"条件段内无关键帧: {start_frame_compressed}-{condition_end_compressed}") + return None + + # 使用条件段的最后一个关键帧作为reference + reference_keyframe_compressed = max(condition_keyframes_compressed) + + # 🔧 找到对应的原始关键帧索引用于pose查找 + reference_keyframe_original_idx = None + for i, compressed_idx in enumerate(compressed_keyframe_indices): + if compressed_idx == reference_keyframe_compressed: + reference_keyframe_original_idx = i + break + + if reference_keyframe_original_idx is None: + print(f"无法找到reference关键帧的原始索引") + return None + + # 找到目标段对应的原始关键帧索引 + target_keyframes_original_indices = [] + for compressed_idx in target_keyframes_compressed: + for i, comp_idx in enumerate(compressed_keyframe_indices): + if comp_idx == compressed_idx: + target_keyframes_original_indices.append(i) + break + + return { + 'start_frame': start_frame_compressed, # 压缩后的起始帧 + 'condition_frames': condition_frames_compressed, # 压缩后的条件帧数 + 'target_frames': target_frames_compressed, # 压缩后的目标帧数 + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, 
target_end_compressed), + 'reference_keyframe_idx': reference_keyframe_original_idx, # 原始关键帧索引 + 'target_keyframe_indices': target_keyframes_original_indices, # 原始关键帧索引列表 + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, # 用于记录 + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + + + def create_pose_embeddings(self, scene_info, segment_info): + """创建pose embeddings - 修正版本,包含condition和target的实际pose""" + keyframe_poses = scene_info['keyframe_poses'] + reference_keyframe_idx = segment_info['reference_keyframe_idx'] + target_keyframe_indices = segment_info['target_keyframe_indices'] + + if reference_keyframe_idx >= len(keyframe_poses): + return None + + reference_pose = keyframe_poses[reference_keyframe_idx] + + # 🔧 关键修复:pose向量应该包含condition帧和target帧的实际pose数据 + condition_frames = segment_info['condition_frames'] # 压缩后的condition帧数 + target_frames = segment_info['target_frames'] # 压缩后的target帧数 + total_frames = condition_frames + target_frames # 总帧数,与latent对齐 + + print(f"创建pose embedding: condition_frames={condition_frames}, target_frames={target_frames}, total_frames={total_frames}") + + # 🔧 获取condition段的关键帧索引 + start_frame = segment_info['start_frame'] + condition_end_compressed = start_frame + condition_frames + + # 压缩后的关键帧索引 + compressed_keyframe_indices = [idx // self.time_compression_ratio for idx in scene_info['keyframe_indices']] + + # 找到condition段的关键帧 + condition_keyframes_compressed = [idx for idx in compressed_keyframe_indices + if start_frame <= idx < condition_end_compressed] + + # 找到对应的原始关键帧索引 + condition_keyframes_original_indices = [] + for compressed_idx in condition_keyframes_compressed: + for i, comp_idx in enumerate(compressed_keyframe_indices): + if comp_idx == compressed_idx: + condition_keyframes_original_indices.append(i) + break + + pose_vecs = [] + frame_types = [] # 新增:记录每帧是condition还是target + + # 🔧 前面的condition帧使用实际的pose数据 + for i in range(condition_frames): + if not 
condition_keyframes_original_indices: + # 如果condition段没有关键帧,使用reference pose + translation = torch.zeros(3, dtype=torch.float32) + rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) # 单位四元数 + else: + # 为condition帧分配pose + if len(condition_keyframes_original_indices) == 1: + keyframe_idx = condition_keyframes_original_indices[0] + else: + # 线性插值选择关键帧 + if condition_frames == 1: + keyframe_idx = condition_keyframes_original_indices[0] + else: + interp_ratio = i / (condition_frames - 1) + interp_idx = int(interp_ratio * (len(condition_keyframes_original_indices) - 1)) + keyframe_idx = condition_keyframes_original_indices[interp_idx] + + if keyframe_idx >= len(keyframe_poses): + translation = torch.zeros(3, dtype=torch.float32) + rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) + else: + condition_pose = keyframe_poses[keyframe_idx] + + # 计算相对于reference的pose + translation = torch.tensor( + np.array(condition_pose['translation']) - np.array(reference_pose['translation']), + dtype=torch.float32 + ) + + relative_rotation = self.calculate_relative_rotation( + condition_pose['rotation'], + reference_pose['rotation'] + ) + + rotation = relative_rotation + + # 🔧 添加frame type embedding:0表示condition + pose_vec = torch.cat([translation, rotation, torch.tensor([0.0], dtype=torch.float32)], dim=0) # [3+4+1=8D] + pose_vecs.append(pose_vec) + frame_types.append('condition') + + # 🔧 后面的target帧使用实际的pose数据 + if not target_keyframe_indices: + # 如果目标段没有关键帧,target帧使用零向量 + for i in range(target_frames): + pose_vec = torch.cat([ + torch.zeros(3, dtype=torch.float32), # translation + torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32), # rotation + torch.tensor([1.0], dtype=torch.float32) # frame type: 1表示target + ], dim=0) + pose_vecs.append(pose_vec) + frame_types.append('target') + else: + # 为每个target帧分配pose + for i in range(target_frames): + if len(target_keyframe_indices) == 1: + # 只有一个关键帧,所有target帧使用相同的pose + target_keyframe_idx = 
target_keyframe_indices[0] + else: + # 多个关键帧,线性插值选择 + if target_frames == 1: + # 只有一帧,使用第一个关键帧 + target_keyframe_idx = target_keyframe_indices[0] + else: + # 线性插值 + interp_ratio = i / (target_frames - 1) + interp_idx = int(interp_ratio * (len(target_keyframe_indices) - 1)) + target_keyframe_idx = target_keyframe_indices[interp_idx] + + if target_keyframe_idx >= len(keyframe_poses): + pose_vec = torch.cat([ + torch.zeros(3, dtype=torch.float32), + torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32), + torch.tensor([1.0], dtype=torch.float32) # target + ], dim=0) + else: + target_pose = keyframe_poses[target_keyframe_idx] + + # 计算相对pose + relative_translation = torch.tensor( + np.array(target_pose['translation']) - np.array(reference_pose['translation']), + dtype=torch.float32 + ) + + relative_rotation = self.calculate_relative_rotation( + target_pose['rotation'], + reference_pose['rotation'] + ) + + # 🔧 添加frame type embedding:1表示target + pose_vec = torch.cat([ + relative_translation, + relative_rotation, + torch.tensor([1.0], dtype=torch.float32) + ], dim=0) # [8D] + + pose_vecs.append(pose_vec) + frame_types.append('target') + + if not pose_vecs: + print("❌ 没有生成任何pose向量") + return None + + pose_sequence = torch.stack(pose_vecs, dim=0) # [total_frames, 8] + print(f"生成pose序列形状: {pose_sequence.shape}") + print(f"期望形状: [{total_frames}, 8]") + print(f"帧类型分布: {frame_types}") + + # 🔧 只对target部分进行分类分析(condition部分不需要分类) + target_pose_sequence = pose_sequence[condition_frames:, :7] # 只取target部分的前7维 + + if target_pose_sequence.numel() == 0: + print("❌ Target pose序列为空") + return None + + # 使用分类器分析target部分 + pose_analysis = self.pose_classifier.analyze_pose_sequence(target_pose_sequence) + + # 过滤掉backward样本 + class_distribution = pose_analysis['class_distribution'] + # if 'backward' in class_distribution and class_distribution['backward'] > 0: + # print(f"⚠️ 检测到backward运动,跳过样本") + # return None + + # 🔧 创建完整的类别embedding(包含condition和target) + # 
condition帧的类别标签设为forward(或者可以设为特殊的"condition"类别) + condition_classes = torch.full((condition_frames,), 0, dtype=torch.long) # 0表示forward/condition + target_classes = pose_analysis['classifications'] + + # 拼接condition和target的类别 + full_classes = torch.cat([condition_classes, target_classes], dim=0) + + # 🔧 创建enhanced class embedding,包含frame type信息 + class_embeddings = self.create_enhanced_class_embedding( + full_classes, pose_sequence, embed_dim=512 + ) + + print(f"最终class embedding形状: {class_embeddings.shape}") + print(f"期望形状: [{total_frames}, 512]") + + # 🔧 验证embedding形状是否正确 + if class_embeddings.shape[0] != total_frames: + print(f"❌ Embedding帧数不匹配: {class_embeddings.shape[0]} != {total_frames}") + return None + + return { + 'raw_poses': pose_sequence, # [total_frames, 8] 包含condition和target的实际pose + frame type + 'pose_classes': full_classes, # [total_frames] 包含condition和target的类别 + 'class_embeddings': class_embeddings, # [total_frames, 512] 增强的embedding + 'pose_analysis': pose_analysis, # 只包含target部分的分析 + 'condition_frames': condition_frames, + 'target_frames': target_frames, + 'frame_types': frame_types + } + + def create_enhanced_class_embedding(self, class_labels: torch.Tensor, pose_sequence: torch.Tensor, embed_dim: int = 512) -> torch.Tensor: + """ + 创建增强的类别embedding,包含frame type和pose信息 + Args: + class_labels: [num_frames] 类别标签 + pose_sequence: [num_frames, 8] pose序列,最后一维是frame type + embed_dim: embedding维度 + Returns: + embeddings: [num_frames, embed_dim] + """ + num_classes = 4 + num_frames = len(class_labels) + + # 基础的方向embedding + direction_vectors = torch.tensor([ + [1.0, 0.0, 0.0, 0.0], # forward: 主要x分量 + [-1.0, 0.0, 0.0, 0.0], # backward: 负x分量 + [0.0, 1.0, 0.0, 0.0], # left_turn: 主要y分量 + [0.0, -1.0, 0.0, 0.0], # right_turn: 负y分量 + ], dtype=torch.float32) + + # One-hot编码 + one_hot = torch.zeros(num_frames, num_classes) + one_hot.scatter_(1, class_labels.unsqueeze(1), 1) + + # 基于方向向量的基础embedding + base_embeddings = one_hot @ direction_vectors # 
[num_frames, 4] + + # 🔧 添加frame type信息 + frame_types = pose_sequence[:, -1] # 最后一维是frame type + frame_type_embeddings = torch.zeros(num_frames, 2) + frame_type_embeddings[:, 0] = (frame_types == 0).float() # condition + frame_type_embeddings[:, 1] = (frame_types == 1).float() # target + + # 🔧 添加pose的几何信息 + translations = pose_sequence[:, :3] # [num_frames, 3] + rotations = pose_sequence[:, 3:7] # [num_frames, 4] + + # 组合所有特征 + combined_features = torch.cat([ + base_embeddings, # [num_frames, 4] 方向特征 + frame_type_embeddings, # [num_frames, 2] 帧类型特征 + translations, # [num_frames, 3] 位移特征 + rotations, # [num_frames, 4] 旋转特征 + ], dim=1) # [num_frames, 13] + + # 扩展到目标维度 + if embed_dim > 13: + # 使用线性变换扩展 + expand_matrix = torch.randn(13, embed_dim) * 0.1 + # 保持重要特征 + expand_matrix[:13, :13] = torch.eye(13) + embeddings = combined_features @ expand_matrix + else: + embeddings = combined_features[:, :embed_dim] + + return embeddings + + def __getitem__(self, index): + while True: + try: + # 随机选择一个场景 + scene_dir = random.choice(self.scene_dirs) + + # 加载场景信息 + with open(os.path.join(scene_dir, "scene_info.json"), 'r') as f: + scene_info = json.load(f) + + # 加载编码的视频数据 + encoded_data = torch.load( + os.path.join(scene_dir, "encoded_video-480p.pth"), + weights_only=True, + map_location="cpu" + ) + + # 🔧 验证latent帧数是否符合预期 + full_latents = encoded_data['latents'] # [C, T, H, W] + expected_latent_frames = scene_info['total_frames'] // self.time_compression_ratio + actual_latent_frames = full_latents.shape[1] + + # print(f"场景 {os.path.basename(scene_dir)}: 原始帧数={scene_info['total_frames']}, " + # f"预期latent帧数={expected_latent_frames}, 实际latent帧数={actual_latent_frames}") + + if abs(actual_latent_frames - expected_latent_frames) > 2: # 允许小的舍入误差 + print(f"⚠️ Latent帧数不匹配,跳过此样本") + continue + + # 动态选择段落 + segment_info = self.select_dynamic_segment(scene_info) + if segment_info is None: + continue + + # 创建pose embeddings + pose_data = self.create_pose_embeddings(scene_info, segment_info) 
+ if pose_data is None: + continue + + # 🔧 使用压缩后的索引提取latent段落 + start_frame = segment_info['start_frame'] # 已经是压缩后的索引 + condition_frames = segment_info['condition_frames'] # 已经是压缩后的帧数 + target_frames = segment_info['target_frames'] # 已经是压缩后的帧数 + + # print(f"提取latent段落: start={start_frame}, condition={condition_frames}, target={target_frames}") + # print(f"Full latents shape: {full_latents.shape}") + + # 确保索引不越界 + if start_frame + condition_frames + target_frames > full_latents.shape[1]: + print(f"索引越界,跳过: {start_frame + condition_frames + target_frames} > {full_latents.shape[1]}") + continue + + condition_latents = full_latents[:, start_frame:start_frame+condition_frames, :, :] + + target_latents = full_latents[:, start_frame+condition_frames:start_frame+condition_frames+target_frames, :, :] + + # print(f"Condition latents shape: {condition_latents.shape}") + # print(f"Target latents shape: {target_latents.shape}") + + # 拼接latents [condition, target] + combined_latents = torch.cat([condition_latents, target_latents], dim=1) + + result = { + "latents": combined_latents, + "prompt_emb": encoded_data["prompt_emb"], + "image_emb": encoded_data.get("image_emb", {}), + "camera": pose_data['class_embeddings'].to(torch.bfloat16), + "pose_classes": pose_data['pose_classes'], + "raw_poses": pose_data['raw_poses'], + "pose_analysis": pose_data['pose_analysis'], + "condition_frames": condition_frames, # 压缩后的帧数 + "target_frames": target_frames, # 压缩后的帧数 + "scene_name": os.path.basename(scene_dir), + # 🔧 新增:记录原始帧数用于调试 + "original_condition_frames": segment_info['original_condition_frames'], + "original_target_frames": segment_info['original_target_frames'], + } + + return result + + except Exception as e: + print(f"Error loading sample: {e}") + import traceback + traceback.print_exc() + continue + + def __len__(self): + return self.steps_per_epoch + +class DynamicLightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + 
use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + resume_ckpt_path=None + ): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + # 添加相机编码器 + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + block.cam_encoder = nn.Linear(512, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + self.pipe.dit.load_state_dict(state_dict, strict=True) + + self.freeze_parameters() + + # 只训练相机相关和注意力模块 + for name, module in self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn"]): + for param in module.parameters(): + param.requires_grad = True + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + # 创建可视化目录 + self.vis_dir = "nus/visualizations_dynamic" + os.makedirs(self.vis_dir, exist_ok=True) + + def freeze_parameters(self): + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + def training_step(self, batch, batch_idx): + # 获取动态长度信息(这些已经是压缩后的帧数) + condition_frames = batch["condition_frames"][0].item() # 压缩后的condition长度 + target_frames = batch["target_frames"][0].item() # 压缩后的target长度 + + # 
🔧 获取原始帧数用于日志记录 + original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0] + original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0] + + # Data + latents = batch["latents"].to(self.device) + # print(f"压缩后condition帧数: {condition_frames}, target帧数: {target_frames}") + # print(f"原始condition帧数: {original_condition_frames}, target帧数: {original_target_frames}") + # print(f"Latents shape: {latents.shape}") + + # 裁剪空间尺寸以节省内存 + # target_height, target_width = 50, 70 + # current_height, current_width = latents.shape[3], latents.shape[4] + + # if current_height > target_height or current_width > target_width: + # h_start = (current_height - target_height) // 2 + # w_start = (current_width - target_width) // 2 + # latents = latents[:, :, :, + # h_start:h_start+target_height, + # w_start:w_start+target_width] + + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + # print(f"裁剪后latents shape: {latents.shape}") + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + cam_emb = batch["camera"].to(self.device) + + # Loss计算 + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + extra_input = self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + # 🔧 关键:使用压缩后的condition长度 + # condition部分保持clean,只对target部分加噪 + noisy_latents[:, :, :condition_frames, ...] = origin_latents[:, :, :condition_frames, ...] 
+ training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + # print(f"targe尺寸: {training_target.shape}") + # 预测噪声 + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + # print(f"pred尺寸: {training_target.shape}") + # 🔧 只对target部分计算loss(使用压缩后的索引) + target_noise_pred = noise_pred[:, :, condition_frames:condition_frames+target_frames, ...] + target_training_target = training_target[:, :, condition_frames:condition_frames+target_frames, ...] + + loss = torch.nn.functional.mse_loss(target_noise_pred.float(), target_training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + print('--------loss------------:',loss) + + # 记录额外信息 + wandb.log({ + "train_loss": loss.item(), + "timestep": timestep.item(), + "condition_frames_compressed": condition_frames, # 压缩后的帧数000 + "target_frames_compressed": target_frames, + "condition_frames_original": original_condition_frames, # 原始帧数 + "target_frames_original": original_target_frames, + "total_frames_compressed": condition_frames + target_frames, + "total_frames_original": original_condition_frames + original_target_frames, + "global_step": self.global_step + }) + + return loss + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + def on_save_checkpoint(self, checkpoint): + checkpoint_dir = "/home/zhuyixuan05/ReCamMaster/nus_dynamic" + os.makedirs(checkpoint_dir, exist_ok=True) + + current_step = self.global_step + checkpoint.clear() + + state_dict = self.pipe.denoising_model().state_dict() + torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}_dynamic.ckpt")) + 
print(f"Saved dynamic model checkpoint: step{current_step}_dynamic.ckpt") + +def train_dynamic(args): + """训练支持动态历史长度的模型""" + dataset = DynamicNuScenesDataset( + args.dataset_path, + steps_per_epoch=args.steps_per_epoch, + min_condition_frames=args.min_condition_frames, + max_condition_frames=args.max_condition_frames, + target_frames=args.target_frames, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + + model = DynamicLightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + ) + + wandb.init( + project="nuscenes-dynamic-recam", + name=f"dynamic-{args.min_condition_frames}-{args.max_condition_frames}", + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[], + ) + trainer.fit(model, dataloader) + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description="Train Dynamic ReCamMaster") + parser.add_argument("--dataset_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic") + parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors") + parser.add_argument("--output_path", type=str, default="./") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--steps_per_epoch", type=int, default=3000) + parser.add_argument("--max_epochs", type=int, default=10) + parser.add_argument("--min_condition_frames", type=int, default=10, help="最小条件帧数") + parser.add_argument("--max_condition_frames", type=int, 
default=40, help="最大条件帧数") + parser.add_argument("--target_frames", type=int, default=32, help="目标帧数") + parser.add_argument("--dataloader_num_workers", type=int, default=4) + parser.add_argument("--accumulate_grad_batches", type=int, default=1) + parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1") + parser.add_argument("--use_gradient_checkpointing", action="store_true") + parser.add_argument("--use_gradient_checkpointing_offload", action="store_true") + parser.add_argument("--resume_ckpt_path", type=str, default=None) + + args = parser.parse_args() + + train_dynamic(args) \ No newline at end of file diff --git a/scripts/train_nus_framepack.py b/scripts/train_nus_framepack.py new file mode 100644 index 0000000000000000000000000000000000000000..db3db5d306e5b873f4dacd9d593ae2d70b1b33f8 --- /dev/null +++ b/scripts/train_nus_framepack.py @@ -0,0 +1,840 @@ +import torch +import torch.nn as nn +import lightning as pl +import wandb +import os +import copy +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import os +import json +import torch +import numpy as np +from PIL import Image +import imageio +import random +from torchvision.transforms import v2 +from einops import rearrange +from pose_classifier import PoseClassifier + +class DynamicNuScenesDataset(torch.utils.data.Dataset): + """支持FramePack机制的动态历史长度NuScenes数据集""" + + def __init__(self, base_path, steps_per_epoch, + min_condition_frames=10, max_condition_frames=40, + target_frames=10, height=900, width=1600): + self.base_path = base_path + self.scenes_path = os.path.join(base_path, "scenes") + self.min_condition_frames = min_condition_frames + self.max_condition_frames = max_condition_frames + self.target_frames = target_frames + self.height = height + self.width = width + self.steps_per_epoch = steps_per_epoch + self.pose_classifier = PoseClassifier() + + # 🔧 新增:VAE时间压缩比例 + self.time_compression_ratio = 4 # VAE将时间维度压缩4倍 + + # 查找所有处理好的场景 + self.scene_dirs = [] + if 
os.path.exists(self.scenes_path): + for item in os.listdir(self.scenes_path): + scene_dir = os.path.join(self.scenes_path, item) + if os.path.isdir(scene_dir): + scene_info_path = os.path.join(scene_dir, "scene_info.json") + if os.path.exists(scene_info_path): + # 检查是否有编码的tensor文件 + encoded_path = os.path.join(scene_dir, "encoded_video-480p.pth") + if os.path.exists(encoded_path): + self.scene_dirs.append(scene_dir) + + assert len(self.scene_dirs) > 0, "No encoded scenes found!" + + def calculate_relative_rotation(self, current_rotation, reference_rotation): + """计算相对旋转四元数""" + q_current = torch.tensor(current_rotation, dtype=torch.float32) + q_ref = torch.tensor(reference_rotation, dtype=torch.float32) + + q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]]) + + w1, x1, y1, z1 = q_ref_inv + w2, x2, y2, z2 = q_current + + relative_rotation = torch.tensor([ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + ]) + + return relative_rotation + + def select_dynamic_segment_framepack(self, scene_info): + """🔧 FramePack风格的动态选择条件帧和目标帧""" + keyframe_indices = scene_info['keyframe_indices'] # 原始帧索引 + total_frames = scene_info['total_frames'] # 原始总帧数 + + if len(keyframe_indices) < 2: + print('error1____________') + return None + + # 🔧 计算压缩后的帧数 + compressed_total_frames = total_frames // self.time_compression_ratio + compressed_keyframe_indices = [idx // self.time_compression_ratio for idx in keyframe_indices] + + # 随机选择条件帧长度(基于压缩后的帧数) + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + + # 🔧 FramePack风格的采样策略 + ratio = random.random() + print('ratio:', ratio) + if ratio < 0.15: + condition_frames_compressed = 1 + elif 0.15 <= ratio < 0.9: + 
condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if compressed_total_frames < min_required_frames: + print(f"压缩后帧数不足: {compressed_total_frames} < {min_required_frames}") + return None + + # 随机选择起始位置(基于压缩后的帧数) + max_start = compressed_total_frames - min_required_frames - 1 + start_frame_compressed = random.randint(0, max_start) + + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + # 🔧 FramePack风格的索引处理 + latent_indices = torch.arange(condition_end_compressed, target_end_compressed) # 只预测未来帧 + + # 1x帧:起始帧 + 最后1帧 + clean_latent_indices_start = torch.tensor([start_frame_compressed]) + clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1]) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices]) + + # 🔧 2x帧:根据实际condition长度确定 + if condition_frames_compressed >= 2: + # 取最后2帧(如果有的话) + clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2-1) + clean_latent_2x_indices = torch.arange(clean_latent_2x_start, condition_end_compressed-1) + else: + # 如果condition帧数不足2帧,创建空索引 + clean_latent_2x_indices = torch.tensor([], dtype=torch.long) + + # 🔧 4x帧:根据实际condition长度确定,最多16帧 + if condition_frames_compressed > 3: + # 取最多16帧的历史(如果有的话) + clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16-3) + clean_latent_4x_indices = torch.arange(clean_4x_start, condition_end_compressed-3) + else: + clean_latent_4x_indices = torch.tensor([], dtype=torch.long) + + # 🔧 关键修复:在压缩空间中查找关键帧 + condition_keyframes_compressed = [idx for idx in compressed_keyframe_indices + if start_frame_compressed <= idx < condition_end_compressed] + + target_keyframes_compressed = [idx for idx in 
compressed_keyframe_indices + if condition_end_compressed <= idx < target_end_compressed] + + if not condition_keyframes_compressed: + print(f"条件段内无关键帧: {start_frame_compressed}-{condition_end_compressed}") + return None + + # 使用条件段的最后一个关键帧作为reference + reference_keyframe_compressed = max(condition_keyframes_compressed) + + # 🔧 找到对应的原始关键帧索引用于pose查找 + reference_keyframe_original_idx = None + for i, compressed_idx in enumerate(compressed_keyframe_indices): + if compressed_idx == reference_keyframe_compressed: + reference_keyframe_original_idx = i + break + + if reference_keyframe_original_idx is None: + print(f"无法找到reference关键帧的原始索引") + return None + + # 找到目标段对应的原始关键帧索引 + target_keyframes_original_indices = [] + for compressed_idx in target_keyframes_compressed: + for i, comp_idx in enumerate(compressed_keyframe_indices): + if comp_idx == compressed_idx: + target_keyframes_original_indices.append(i) + break + + return { + 'start_frame': start_frame_compressed, # 压缩后的起始帧 + 'condition_frames': condition_frames_compressed, # 压缩后的条件帧数 + 'target_frames': target_frames_compressed, # 压缩后的目标帧数 + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + 'reference_keyframe_idx': reference_keyframe_original_idx, # 原始关键帧索引 + 'target_keyframe_indices': target_keyframes_original_indices, # 原始关键帧索引列表 + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, # 用于记录 + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + + # 🔧 FramePack风格的索引 + 'latent_indices': latent_indices, + 'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + } + + def create_pose_embeddings(self, scene_info, segment_info): + """🔧 为所有帧(condition + target)创建pose embeddings - FramePack风格""" + keyframe_poses = scene_info['keyframe_poses'] + reference_keyframe_idx = 
segment_info['reference_keyframe_idx'] + target_keyframe_indices = segment_info['target_keyframe_indices'] + + if reference_keyframe_idx >= len(keyframe_poses): + return None + + reference_pose = keyframe_poses[reference_keyframe_idx] + + # 🔧 为所有帧(condition + target)计算pose embeddings + start_frame = segment_info['start_frame'] + condition_end_compressed = start_frame + segment_info['condition_frames'] + target_end_compressed = condition_end_compressed + segment_info['target_frames'] + + # 压缩后的关键帧索引 + compressed_keyframe_indices = [idx // self.time_compression_ratio for idx in scene_info['keyframe_indices']] + + # 找到condition段的关键帧 + condition_keyframes_compressed = [idx for idx in compressed_keyframe_indices + if start_frame <= idx < condition_end_compressed] + + # 找到对应的原始关键帧索引 + condition_keyframes_original_indices = [] + for compressed_idx in condition_keyframes_compressed: + for i, comp_idx in enumerate(compressed_keyframe_indices): + if comp_idx == compressed_idx: + condition_keyframes_original_indices.append(i) + break + + pose_vecs = [] + frame_types = [] + + # 🔧 为condition帧计算pose + for i in range(segment_info['condition_frames']): + if not condition_keyframes_original_indices: + translation = torch.zeros(3, dtype=torch.float32) + rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) + else: + # 为condition帧分配pose + if len(condition_keyframes_original_indices) == 1: + keyframe_idx = condition_keyframes_original_indices[0] + else: + if segment_info['condition_frames'] == 1: + keyframe_idx = condition_keyframes_original_indices[0] + else: + interp_ratio = i / (segment_info['condition_frames'] - 1) + interp_idx = int(interp_ratio * (len(condition_keyframes_original_indices) - 1)) + keyframe_idx = condition_keyframes_original_indices[interp_idx] + + if keyframe_idx >= len(keyframe_poses): + translation = torch.zeros(3, dtype=torch.float32) + rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) + else: + condition_pose = 
keyframe_poses[keyframe_idx] + + translation = torch.tensor( + np.array(condition_pose['translation']) - np.array(reference_pose['translation']), + dtype=torch.float32 + ) + + relative_rotation = self.calculate_relative_rotation( + condition_pose['rotation'], + reference_pose['rotation'] + ) + + rotation = relative_rotation + + # 🔧 添加frame type embedding:0表示condition + pose_vec = torch.cat([translation, rotation, torch.tensor([0.0], dtype=torch.float32)], dim=0) # [8D] + pose_vecs.append(pose_vec) + frame_types.append('condition') + + # 🔧 为target帧计算pose + if not target_keyframe_indices: + for i in range(segment_info['target_frames']): + pose_vec = torch.cat([ + torch.zeros(3, dtype=torch.float32), + torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32), + torch.tensor([1.0], dtype=torch.float32) # frame type: 1表示target + ], dim=0) + pose_vecs.append(pose_vec) + frame_types.append('target') + else: + for i in range(segment_info['target_frames']): + if len(target_keyframe_indices) == 1: + target_keyframe_idx = target_keyframe_indices[0] + else: + if segment_info['target_frames'] == 1: + target_keyframe_idx = target_keyframe_indices[0] + else: + interp_ratio = i / (segment_info['target_frames'] - 1) + interp_idx = int(interp_ratio * (len(target_keyframe_indices) - 1)) + target_keyframe_idx = target_keyframe_indices[interp_idx] + + if target_keyframe_idx >= len(keyframe_poses): + pose_vec = torch.cat([ + torch.zeros(3, dtype=torch.float32), + torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32), + torch.tensor([1.0], dtype=torch.float32) + ], dim=0) + else: + target_pose = keyframe_poses[target_keyframe_idx] + + relative_translation = torch.tensor( + np.array(target_pose['translation']) - np.array(reference_pose['translation']), + dtype=torch.float32 + ) + + relative_rotation = self.calculate_relative_rotation( + target_pose['rotation'], + reference_pose['rotation'] + ) + + # 🔧 添加frame type embedding:1表示target + pose_vec = torch.cat([ + relative_translation, + 
relative_rotation, + torch.tensor([1.0], dtype=torch.float32) + ], dim=0) + + pose_vecs.append(pose_vec) + frame_types.append('target') + + if not pose_vecs: + print("❌ 没有生成任何pose向量") + return None + + pose_sequence = torch.stack(pose_vecs, dim=0) # [total_frames, 8] + + # 🔧 只对target部分进行分类分析 + target_pose_sequence = pose_sequence[segment_info['condition_frames']:, :7] + + if target_pose_sequence.numel() == 0: + print("❌ Target pose序列为空") + return None + + # 使用分类器分析target部分 + pose_analysis = self.pose_classifier.analyze_pose_sequence(target_pose_sequence) + + # 🔧 创建完整的类别embedding + condition_classes = torch.full((segment_info['condition_frames'],), 0, dtype=torch.long) + target_classes = pose_analysis['classifications'] + + full_classes = torch.cat([condition_classes, target_classes], dim=0) + + # 🔧 创建enhanced class embedding + class_embeddings = self.create_enhanced_class_embedding( + full_classes, pose_sequence, embed_dim=512 + ) + + return class_embeddings + + def create_enhanced_class_embedding(self, class_labels: torch.Tensor, pose_sequence: torch.Tensor, embed_dim: int = 512) -> torch.Tensor: + """创建增强的类别embedding""" + num_classes = 4 + num_frames = len(class_labels) + + direction_vectors = torch.tensor([ + [1.0, 0.0, 0.0, 0.0], + [-1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, -1.0, 0.0, 0.0], + ], dtype=torch.float32) + + one_hot = torch.zeros(num_frames, num_classes) + one_hot.scatter_(1, class_labels.unsqueeze(1), 1) + + base_embeddings = one_hot @ direction_vectors + + frame_types = pose_sequence[:, -1] + frame_type_embeddings = torch.zeros(num_frames, 2) + frame_type_embeddings[:, 0] = (frame_types == 0).float() + frame_type_embeddings[:, 1] = (frame_types == 1).float() + + translations = pose_sequence[:, :3] + rotations = pose_sequence[:, 3:7] + + combined_features = torch.cat([ + base_embeddings, + frame_type_embeddings, + translations, + rotations, + ], dim=1) + + if embed_dim > 13: + expand_matrix = torch.randn(13, embed_dim) * 0.1 + 
            # NOTE(review): tail of create_pose_embeddings(); the definition starts
            # before this chunk, so the surrounding indentation is reconstructed.
            expand_matrix[:13, :13] = torch.eye(13)
            embeddings = combined_features @ expand_matrix
        else:
            # Requested embedding width is narrower than the raw features: truncate.
            embeddings = combined_features[:, :embed_dim]

        return embeddings

    def prepare_framepack_inputs(self, full_latents, segment_info):
        """Build FramePack-style multi-scale conditioning inputs.

        Args:
            full_latents: latent video tensor, [C, T, H, W] or [B, C, T, H, W].
            segment_info: dict from select_dynamic_segment_framepack() carrying
                the index tensors used below.

        Returns:
            dict with the denoising target ('latents'), 1x/2x/4x condition
            latents (2x fixed to 2 time slots, 4x fixed to 16, zero-padded on
            the left) and matching index tensors (-1 marks padded slots).
        """
        # Normalize to 5-D so the fancy indexing below is uniform.
        if len(full_latents.shape) == 4:
            full_latents = full_latents.unsqueeze(0)
            B, C, T, H, W = full_latents.shape
        else:
            B, C, T, H, W = full_latents.shape

        # Main latents: the frames the model must denoise/predict.
        latent_indices = segment_info['latent_indices']
        main_latents = full_latents[:, :, latent_indices, :, :]

        # 1x condition frames (start frame + last condition frame).
        clean_latent_indices = segment_info['clean_latent_indices']
        clean_latents = full_latents[:, :, clean_latent_indices, :, :]

        # 4x condition frames: always 16 slots; real indices + zero padding.
        clean_latent_4x_indices = segment_info['clean_latent_4x_indices']

        clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype)
        clean_latent_4x_indices_final = torch.full((16,), -1, dtype=torch.long)  # -1 == padding

        if len(clean_latent_4x_indices) > 0:
            actual_4x_frames = len(clean_latent_4x_indices)
            # Fill from the right so the most recent frames sit last.
            start_pos = max(0, 16 - actual_4x_frames)
            end_pos = 16
            actual_start = max(0, actual_4x_frames - 16)  # keep only the last 16

            clean_latents_4x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :]
            clean_latent_4x_indices_final[start_pos:end_pos] = clean_latent_4x_indices[actual_start:]

        # 2x condition frames: always 2 slots; same right-aligned padding scheme.
        clean_latent_2x_indices = segment_info['clean_latent_2x_indices']

        clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype)
        clean_latent_2x_indices_final = torch.full((2,), -1, dtype=torch.long)  # -1 == padding

        if len(clean_latent_2x_indices) > 0:
            actual_2x_frames = len(clean_latent_2x_indices)
            start_pos = max(0, 2 - actual_2x_frames)
            end_pos = 2
            actual_start = max(0, actual_2x_frames - 2)  # keep only the last 2

            clean_latents_2x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_2x_indices[actual_start:], :, :]
            clean_latent_2x_indices_final[start_pos:end_pos] = clean_latent_2x_indices[actual_start:]

        # Drop the batch dim again when the input was 4-D (then B == 1 by construction).
        if B == 1:
            main_latents = main_latents.squeeze(0)
            clean_latents = clean_latents.squeeze(0)
            clean_latents_2x = clean_latents_2x.squeeze(0)
            clean_latents_4x = clean_latents_4x.squeeze(0)

        return {
            'latents': main_latents,
            'clean_latents': clean_latents,
            'clean_latents_2x': clean_latents_2x,
            'clean_latents_4x': clean_latents_4x,
            'latent_indices': segment_info['latent_indices'],
            'clean_latent_indices': segment_info['clean_latent_indices'],
            'clean_latent_2x_indices': clean_latent_2x_indices_final,
            'clean_latent_4x_indices': clean_latent_4x_indices_final,
        }

    def __getitem__(self, index):
        """Sample one training example; retries forever on bad scenes.

        NOTE(review): `index` is ignored — every call draws a random scene,
        so epoch coverage is stochastic rather than exhaustive.
        """
        while True:
            try:
                # Pick a random scene directory.
                scene_dir = random.choice(self.scene_dirs)

                # Load per-scene metadata.
                with open(os.path.join(scene_dir, "scene_info.json"), 'r') as f:
                    scene_info = json.load(f)

                # Load the pre-encoded (VAE) video latents.
                encoded_data = torch.load(
                    os.path.join(scene_dir, "encoded_video-480p.pth"),
                    weights_only=True,
                    map_location="cpu"
                )

                full_latents = encoded_data['latents']  # [C, T, H, W]
                expected_latent_frames = scene_info['total_frames'] // self.time_compression_ratio
                actual_latent_frames = full_latents.shape[1]

                # Tolerate a small (±2 latent frame) mismatch from VAE rounding; skip otherwise.
                if abs(actual_latent_frames - expected_latent_frames) > 2:
                    print(f"⚠️ Latent帧数不匹配,跳过此样本")
                    continue

                # FramePack-style segment selection (condition + target windows).
                segment_info = self.select_dynamic_segment_framepack(scene_info)
                if segment_info is None:
                    continue

                # Camera pose embeddings for the selected frames.
                pose_embeddings = self.create_pose_embeddings(scene_info, segment_info)
                if pose_embeddings is None:
                    continue

                # FramePack-style multi-scale inputs.
                framepack_inputs = self.prepare_framepack_inputs(full_latents, segment_info)

                n = segment_info["condition_frames"]
                m = segment_info['target_frames']

                # Binary mask column appended to the pose embedding: 1 = condition frame.
                mask = torch.zeros(n+m, dtype=torch.float32)
                mask[:n] = 1.0
                mask = mask.view(-1, 1)

                camera_with_mask = torch.cat([pose_embeddings, mask], dim=1)

                result = {
                    # FramePack-style multi-scale inputs.
                    "latents": framepack_inputs['latents'],
                    "clean_latents": framepack_inputs['clean_latents'],
                    "clean_latents_2x": framepack_inputs['clean_latents_2x'],
                    "clean_latents_4x": framepack_inputs['clean_latents_4x'],
                    "latent_indices": framepack_inputs['latent_indices'],
                    "clean_latent_indices": framepack_inputs['clean_latent_indices'],
                    "clean_latent_2x_indices": framepack_inputs['clean_latent_2x_indices'],
                    "clean_latent_4x_indices": framepack_inputs['clean_latent_4x_indices'],

                    # Camera data (pose embedding + condition-mask column).
                    "camera": camera_with_mask,

                    "prompt_emb": encoded_data["prompt_emb"],
                    "image_emb": encoded_data.get("image_emb", {}),
                    "condition_frames": n,
                    "target_frames": m,
                    "scene_name": os.path.basename(scene_dir),
                    "original_condition_frames": segment_info['original_condition_frames'],
                    "original_target_frames": segment_info['original_target_frames'],
                }

                return result

            except Exception as e:
                # Broad catch is deliberate: any corrupt sample is logged and skipped.
                print(f"Error loading sample: {e}")
                import traceback
                traceback.print_exc()
                continue

    def __len__(self):
        # Virtual length: samples are drawn at random, so the "epoch" size is a knob.
        return self.steps_per_epoch

def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' loader class for WanModelFuture.

    Mutates the global `model_loader_configs` registry in place, so it must be
    called before any ModelManager.load_models() that instantiates the DiT.
    """
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    for i, config in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config

        if 'wan_video_dit' in model_names:
            new_model_names = []
            new_model_classes = []

            for name, cls in zip(model_names, model_classes):
                if name == 'wan_video_dit':
                    new_model_names.append(name)                  # keep the registry name
                    new_model_classes.append(WanModelFuture)      # swap in the new class
                    print(f"✅ 替换了模型类: {name} -> WanModelFuture")
                else:
                    new_model_names.append(name)
                    new_model_classes.append(cls)

            model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource)

class DynamicLightningModelForTrain(pl.LightningModule):
    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        resume_ckpt_path=None
    ):
        super().__init__()
        replace_dit_model_in_manager()  # must run before load_models() below
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        if os.path.isfile(dit_path):
            model_manager.load_models([dit_path])
        else:
            # Comma-separated list of shard paths.
            dit_path = dit_path.split(",")
            model_manager.load_models([dit_path])
        model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"])

        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)

        # Attach FramePack's clean_x_embedder before any checkpoint is loaded.
        self.add_framepack_components()

        # Per-block camera encoders; projector starts as identity so the
        # untrained model initially behaves like the base DiT.
        dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
        for block in self.pipe.dit.blocks:
            block.cam_encoder = nn.Linear(513, dim)  # 512 + 1 for mask
            block.projector = nn.Linear(dim, dim)
            block.cam_encoder.weight.data.zero_()
            block.cam_encoder.bias.data.zero_()
            block.projector.weight = nn.Parameter(torch.eye(dim))
            block.projector.bias = nn.Parameter(torch.zeros(dim))

        if resume_ckpt_path is not None:
            state_dict = torch.load(resume_ckpt_path, map_location="cpu")
            self.pipe.dit.load_state_dict(state_dict, strict=True)
            print('load checkpoint:', resume_ckpt_path)

        self.freeze_parameters()

        # Unfreeze only camera modules, self-attention and FramePack components.
        for name, module in self.pipe.denoising_model().named_modules():
            if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn", "clean_x_embedder"]):
                for param in module.parameters():
                    param.requires_grad = True

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

        # Directory for training visualizations.
        self.vis_dir = "nus/visualizations_dynamic_framepack"
        os.makedirs(self.vis_dir, exist_ok=True)

    def add_framepack_components(self):
        """Attach FramePack's multi-scale clean-latent embedder to the DiT."""
        if not hasattr(self.pipe.dit, 'clean_x_embedder'):
            inner_dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]

            class CleanXEmbedder(nn.Module):
                # Patchifies 16-channel clean latents at 1x/2x/4x spatio-temporal
                # downsampling (kernel == stride, i.e. non-overlapping patches).
                def __init__(self, inner_dim):
                    super().__init__()
                    self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
                    self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
                    self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

                def forward(self, x, scale="1x"):
                    if scale == "1x":
                        return self.proj(x)
                    elif scale == "2x":
                        return self.proj_2x(x)
                    elif scale == "4x":
                        return self.proj_4x(x)
                    else:
                        raise ValueError(f"Unsupported scale: {scale}")

            self.pipe.dit.clean_x_embedder = CleanXEmbedder(inner_dim)
            print("✅ 添加了FramePack的clean_x_embedder组件")

    def freeze_parameters(self):
        # Freeze everything; __init__ selectively re-enables grads afterwards.
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()

    def training_step(self, batch, batch_idx):
        """FramePack-style training step (batch_size is fixed to 1 upstream)."""
        condition_frames = batch["condition_frames"][0].item()
        target_frames = batch["target_frames"][0].item()

        original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0]
        original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0]

        scene_name = batch.get("scene_name", ["unknown"])[0]

        # FramePack-style inputs; re-add the batch dim squeezed out in the dataset.
        latents = batch["latents"].to(self.device)
        if len(latents.shape) == 4:
            latents = latents.unsqueeze(0)

        # Empty tensors (numel == 0) signal "no condition at this scale".
        clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None
        if clean_latents is not None and len(clean_latents.shape) == 4:
            clean_latents = clean_latents.unsqueeze(0)

        clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None
        if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4:
            clean_latents_2x = clean_latents_2x.unsqueeze(0)

        clean_latents_4x = batch["clean_latents_4x"].to(self.device) if batch["clean_latents_4x"].numel() > 0 else None
        if clean_latents_4x is not None and len(clean_latents_4x.shape) == 4:
            clean_latents_4x = clean_latents_4x.unsqueeze(0)

        # Index tensors (None when the corresponding scale is absent).
        latent_indices = batch["latent_indices"].to(self.device)
        clean_latent_indices = batch["clean_latent_indices"].to(self.device) if batch["clean_latent_indices"].numel() > 0 else None
        clean_latent_2x_indices = batch["clean_latent_2x_indices"].to(self.device) if batch["clean_latent_2x_indices"].numel() > 0 else None
        clean_latent_4x_indices = batch["clean_latent_4x_indices"].to(self.device) if batch["clean_latent_4x_indices"].numel() > 0 else None

        # Camera embedding, randomly zeroed for classifier-free guidance.
        cam_emb = batch["camera"].to(self.device)
        camera_dropout_prob = 0.1
        if random.random() < camera_dropout_prob:
            cam_emb = torch.zeros_like(cam_emb)
            print("应用camera dropout for CFG training")

        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]

        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        # Diffusion loss setup: sample one timestep per step.
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)

        # FramePack-style conditioning noise: with 80% probability, corrupt the
        # clean condition latents with a timestep from the first 3/4 of the schedule.
        noisy_condition_latents = None
        if clean_latents is not None:
            noisy_condition_latents = copy.deepcopy(clean_latents)
            is_add_noise = random.random()
            if is_add_noise > 0.2:
                noise_cond = torch.randn_like(clean_latents)
                timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps//4*3, (1,))
                timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
                noisy_condition_latents = self.pipe.scheduler.add_noise(clean_latents, noise_cond, timestep_cond)

        extra_input = self.pipe.prepare_extra_input(latents)
        origin_latents = copy.deepcopy(latents)  # NOTE(review): unused afterwards
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)

        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # FramePack-style forward pass through the (patched) DiT.
        noise_pred = self.pipe.denoising_model()(
            noisy_latents,
            timestep=timestep,
            cam_emb=cam_emb,
            # FramePack-style conditioning inputs.
            latent_indices=latent_indices,
            clean_latents=noisy_condition_latents if noisy_condition_latents is not None else clean_latents,
            clean_latent_indices=clean_latent_indices,
            clean_latents_2x=clean_latents_2x,
            clean_latent_2x_indices=clean_latent_2x_indices,
            clean_latents_4x=clean_latents_4x,
            clean_latent_4x_indices=clean_latent_4x_indices,
            **prompt_emb,
            **extra_input,
            **image_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )

        # Timestep-weighted MSE loss.
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)
        print('--------loss------------:', loss)

        # Log step metrics to wandb.
        wandb.log({
            "train_loss": loss.item(),
            "timestep": timestep.item(),
            "condition_frames_compressed": condition_frames,
            "target_frames_compressed": target_frames,
            "condition_frames_original": original_condition_frames,
            "target_frames_original": original_target_frames,
            "has_clean_latents": clean_latents is not None,
            "has_clean_latents_2x": clean_latents_2x is not None,
            "has_clean_latents_4x": clean_latents_4x is not None,
            "total_frames_compressed": target_frames,
            "total_frames_original": original_target_frames,
            "scene_name": scene_name,
            "global_step": self.global_step
        })

        return loss

    def configure_optimizers(self):
        # Optimize only the parameters unfrozen in __init__.
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer
    def on_save_checkpoint(self, checkpoint):
        """Save a raw state_dict to a fixed directory instead of Lightning's format.

        NOTE(review): checkpoint.clear() empties Lightning's own checkpoint dict,
        so resuming via Trainer(ckpt_path=...) will not restore optimizer/step
        state — only the manual .ckpt files below survive.
        """
        checkpoint_dir = "/home/zhuyixuan05/ReCamMaster/nus_dynamic_framepack"
        os.makedirs(checkpoint_dir, exist_ok=True)

        current_step = self.global_step
        checkpoint.clear()

        state_dict = self.pipe.denoising_model().state_dict()
        torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}_framepack.ckpt"))
        print(f"Saved FramePack model checkpoint: step{current_step}_framepack.ckpt")

def train_dynamic(args):
    """Train the dynamic-history FramePack model on NuScenes."""
    dataset = DynamicNuScenesDataset(
        args.dataset_path,
        steps_per_epoch=args.steps_per_epoch,
        min_condition_frames=args.min_condition_frames,
        max_condition_frames=args.max_condition_frames,
        target_frames=args.target_frames,
    )

    # batch_size is fixed to 1: samples have variable temporal shapes.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    model = DynamicLightningModelForTrain(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        resume_ckpt_path=args.resume_ckpt_path,
    )

    wandb.init(
        project="nuscenes-dynamic-framepack-recam",
        name=f"framepack-{args.min_condition_frames}-{args.max_condition_frames}",
    )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[],
    )
    trainer.fit(model, dataloader)

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description="Train FramePack Dynamic ReCamMaster for NuScenes")
    parser.add_argument("--dataset_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic")
    parser.add_argument("--dit_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
    parser.add_argument("--output_path", type=str, default="./")
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--steps_per_epoch", type=int, default=3000)
    parser.add_argument("--max_epochs", type=int, default=10)
    parser.add_argument("--min_condition_frames", type=int, default=8, help="最小条件帧数")
    parser.add_argument("--max_condition_frames", type=int, default=120, help="最大条件帧数")
    parser.add_argument("--target_frames", type=int, default=32, help="目标帧数")
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1")
    parser.add_argument("--use_gradient_checkpointing", action="store_true")
    parser.add_argument("--use_gradient_checkpointing_offload", action="store_true")
    parser.add_argument("--resume_ckpt_path", type=str, default=None)

    args = parser.parse_args()

    print("🔧 使用FramePack风格训练NuScenes数据集:")
    print(f" - 支持多尺度下采样(1x/2x/4x)")
    print(f" - 使用WanModelFuture模型")
    print(f" - 数据集路径: {args.dataset_path}")

    train_dynamic(args)
\ No newline at end of file
diff --git a/scripts/train_openx.py b/scripts/train_openx.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad49e18c048a112b769586be6b6aadcb8e72122
--- /dev/null
+++ b/scripts/train_openx.py
@@ -0,0 +1,641 @@
import torch
import torch.nn as nn
import lightning as pl
import wandb
import os
import copy
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
import json
import numpy as np
from PIL import Image
import imageio
import random
from torchvision.transforms import v2
from einops import rearrange
from pose_classifier import PoseClassifier

def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A.

    Computes relative = pose_b @ inv(pose_a) for 4x4 extrinsic matrices.

    Args:
        pose_a: 4x4 extrinsic matrix of camera A (numpy array or torch tensor).
        pose_b: 4x4 extrinsic matrix of camera B.
        use_torch: if True compute with torch (inputs converted from numpy as
            needed); otherwise compute with numpy.

    Returns:
        4x4 relative pose, same backend (numpy/torch) as selected by use_torch.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()

        pose_a_inv = torch.inverse(pose_a)
        relative_pose = torch.matmul(pose_b, pose_a_inv)
    else:
        if not isinstance(pose_a, np.ndarray):
            pose_a = np.array(pose_a, dtype=np.float32)
        if not isinstance(pose_b, np.ndarray):
            pose_b = np.array(pose_b, dtype=np.float32)

        pose_a_inv = np.linalg.inv(pose_a)
        relative_pose = np.matmul(pose_b, pose_a_inv)

    return relative_pose

class OpenXFramePackDataset(torch.utils.data.Dataset):
    """FramePack training dataset over pre-encoded OpenX episodes."""

    def __init__(self, base_path, steps_per_epoch,
                 min_condition_frames=10, max_condition_frames=40,
                 target_frames=10, height=480, width=832):
        # Frame counts below are in ORIGINAL video frames; latent-space counts
        # are derived by dividing by time_compression_ratio.
        self.base_path = base_path
        self.min_condition_frames = min_condition_frames
        self.max_condition_frames = max_condition_frames
        self.target_frames = target_frames
        self.height = height
        self.width = width
        self.steps_per_epoch = steps_per_epoch
        self.pose_classifier = PoseClassifier()

        # VAE temporal compression factor (latent frames = video frames / 4).
        self.time_compression_ratio = 4

        # Discover all episode directories that contain encoded latents.
        self.episode_dirs = []
        print(f"🔧 扫描OpenX数据集: {base_path}")

        if os.path.exists(base_path):
            for item in os.listdir(base_path):
                episode_dir = os.path.join(base_path, item)
                if os.path.isdir(episode_dir):
                    encoded_path = os.path.join(episode_dir, "encoded_video.pth")
                    if os.path.exists(encoded_path):
                        self.episode_dirs.append(episode_dir)

            print(f" ✅ 找到 {len(self.episode_dirs)} 个episodes")
        else:
            print(f" ⚠️ 路径不存在: {base_path}")

        assert len(self.episode_dirs) > 0, "No encoded episodes found!"
    def select_dynamic_segment_framepack(self, full_latents):
        """FramePack-style dynamic selection of condition/target windows (OpenX).

        All frame counts here are in compressed (latent) frames.

        Returns:
            dict with the segment layout and 1x/2x/4x index tensors, or None
            when the episode is too short for the sampled configuration.
        """
        total_lens = full_latents.shape[1]

        min_condition_compressed = self.min_condition_frames // self.time_compression_ratio
        max_condition_compressed = self.max_condition_frames // self.time_compression_ratio
        target_frames_compressed = self.target_frames // self.time_compression_ratio
        # Clamp so condition + target (+1) still fits in the episode.
        max_condition_compressed = min(total_lens-target_frames_compressed-1, max_condition_compressed)

        # Sample the condition length: 15% single-frame, ~75% uniform in range,
        # otherwise equal to the target length.
        # NOTE(review): if the clamp above pushes max below min, random.randint
        # raises ValueError (caught upstream by __getitem__'s broad except).
        ratio = random.random()
        if ratio < 0.15:
            condition_frames_compressed = 1
        elif 0.15 <= ratio < 0.9 or total_lens <= 2*target_frames_compressed + 1:
            condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed)
        else:
            condition_frames_compressed = target_frames_compressed

        # Make sure enough frames remain.
        min_required_frames = condition_frames_compressed + target_frames_compressed
        if total_lens < min_required_frames:
            return None

        start_frame_compressed = random.randint(0, total_lens - min_required_frames - 1)
        condition_end_compressed = start_frame_compressed + condition_frames_compressed
        target_end_compressed = condition_end_compressed + target_frames_compressed

        # Frames the model must predict (future frames only).
        latent_indices = torch.arange(condition_end_compressed, target_end_compressed)

        # 1x frames: segment start frame + last condition frame.
        clean_latent_indices_start = torch.tensor([start_frame_compressed])
        clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1])
        clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices])

        # 2x frames: up to 2 frames just before the last condition frame.
        if condition_frames_compressed >= 2:
            clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2-1)
            clean_latent_2x_indices = torch.arange(clean_latent_2x_start, condition_end_compressed-1)
        else:
            # Not enough condition frames: empty index set.
            clean_latent_2x_indices = torch.tensor([], dtype=torch.long)

        # 4x frames: up to 16 history frames before the 2x window.
        if condition_frames_compressed > 3:
            clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16-3)
            clean_latent_4x_indices = torch.arange(clean_4x_start, condition_end_compressed-3)
        else:
            clean_latent_4x_indices = torch.tensor([], dtype=torch.long)

        # Corresponding original (uncompressed) keyframe indices.
        keyframe_original_idx = []
        for compressed_idx in range(start_frame_compressed, target_end_compressed):
            keyframe_original_idx.append(compressed_idx * 4)

        return {
            'start_frame': start_frame_compressed,
            'condition_frames': condition_frames_compressed,
            'target_frames': target_frames_compressed,
            'condition_range': (start_frame_compressed, condition_end_compressed),
            'target_range': (condition_end_compressed, target_end_compressed),

            # FramePack-style index tensors.
            'latent_indices': latent_indices,
            'clean_latent_indices': clean_latent_indices,
            'clean_latent_2x_indices': clean_latent_2x_indices,
            'clean_latent_4x_indices': clean_latent_4x_indices,

            'keyframe_original_idx': keyframe_original_idx,
            'original_condition_frames': condition_frames_compressed * self.time_compression_ratio,
            'original_target_frames': target_frames_compressed * self.time_compression_ratio,
        }

    def create_pose_embeddings(self, cam_data, segment_info):
        """Build per-frame camera pose embeddings for condition + target frames.

        Each embedding is the flattened top 3x4 of the relative pose between a
        keyframe and the keyframe 4 original frames later.

        Returns:
            [num_frames, 12] bfloat16 tensor, or None when no frame had a valid
            next keyframe.
        """
        cam_data_seq = cam_data['extrinsic']  # N * 4 * 4

        # NOTE(review): start_frame/end_frame are computed but never used below.
        start_frame = segment_info['start_frame'] * self.time_compression_ratio
        end_frame = segment_info['target_range'][1] * self.time_compression_ratio

        # Keyframes (one per compressed frame) that have a +4 successor.
        all_keyframe_indices = []
        for compressed_idx in range(segment_info['start_frame'], segment_info['target_range'][1]):
            keyframe_idx = compressed_idx * 4
            if keyframe_idx + 4 < len(cam_data_seq):
                all_keyframe_indices.append(keyframe_idx)

        relative_cams = []
        for idx in all_keyframe_indices:
            if idx + 4 < len(cam_data_seq):
                cam_prev = cam_data_seq[idx]
                cam_next = cam_data_seq[idx + 4]
                relative_cam = compute_relative_pose(cam_prev, cam_next)
                relative_cams.append(torch.as_tensor(relative_cam[:3, :]))
            else:
                # No successor frame: fall back to the identity pose.
                # NOTE(review): unreachable — all_keyframe_indices already
                # filtered on the same condition above.
                identity_cam = torch.eye(3, 4)
                relative_cams.append(identity_cam)

        if len(relative_cams) == 0:
            return None

        pose_embedding = torch.stack(relative_cams, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # [n, 3, 4] -> [n, 12]
        pose_embedding = pose_embedding.to(torch.bfloat16)

        return pose_embedding

    def prepare_framepack_inputs(self, full_latents, segment_info):
        """Build FramePack-style multi-scale inputs (OpenX variant).

        Same contract as the NuScenes version: returns the denoising target
        plus 1x/2x/4x condition latents (2x fixed to 2 time slots, 4x to 16,
        left-padded with zeros; index -1 marks padding).
        """
        # Accept [C, T, H, W] by adding a batch dim for uniform indexing.
        if len(full_latents.shape) == 4:
            full_latents = full_latents.unsqueeze(0)  # [C, T, H, W] -> [1, C, T, H, W]
            B, C, T, H, W = full_latents.shape
        else:
            B, C, T, H, W = full_latents.shape

        # Main latents: denoising/prediction target frames.
        latent_indices = segment_info['latent_indices']
        main_latents = full_latents[:, :, latent_indices, :, :]

        # 1x condition frames (start frame + last condition frame).
        clean_latent_indices = segment_info['clean_latent_indices']
        clean_latents = full_latents[:, :, clean_latent_indices, :, :]

        # 4x condition frames: always 16 slots; real indices + zero padding.
        clean_latent_4x_indices = segment_info['clean_latent_4x_indices']

        clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype)
        clean_latent_4x_indices_final = torch.full((16,), -1, dtype=torch.long)  # -1 == padding

        if len(clean_latent_4x_indices) > 0:
            actual_4x_frames = len(clean_latent_4x_indices)
            # Fill from the right so the most recent frames sit last.
            start_pos = max(0, 16 - actual_4x_frames)
            end_pos = 16
            actual_start = max(0, actual_4x_frames - 16)  # keep only the last 16

            clean_latents_4x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :]
            clean_latent_4x_indices_final[start_pos:end_pos] = clean_latent_4x_indices[actual_start:]

        # 2x condition frames: always 2 slots; same right-aligned padding scheme.
        clean_latent_2x_indices = segment_info['clean_latent_2x_indices']

        clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype)
        clean_latent_2x_indices_final = torch.full((2,), -1, dtype=torch.long)  # -1 == padding

        if len(clean_latent_2x_indices) > 0:
            actual_2x_frames = len(clean_latent_2x_indices)
            start_pos = max(0, 2 - actual_2x_frames)
            end_pos = 2
            actual_start = max(0, actual_2x_frames - 2)  # keep only the last 2

            clean_latents_2x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_2x_indices[actual_start:], :, :]
            clean_latent_2x_indices_final[start_pos:end_pos] = clean_latent_2x_indices[actual_start:]

        # Drop the batch dim added above so callers get the original layout back.
        if B == 1:
            main_latents = main_latents.squeeze(0)  # [1, C, T, H, W] -> [C, T, H, W]
            clean_latents = clean_latents.squeeze(0)
            clean_latents_2x = clean_latents_2x.squeeze(0)
            clean_latents_4x = clean_latents_4x.squeeze(0)

        return {
            'latents': main_latents,
            'clean_latents': clean_latents,
            'clean_latents_2x': clean_latents_2x,
            'clean_latents_4x': clean_latents_4x,
            'latent_indices': segment_info['latent_indices'],
            'clean_latent_indices': segment_info['clean_latent_indices'],
            'clean_latent_2x_indices': clean_latent_2x_indices_final,
            'clean_latent_4x_indices': clean_latent_4x_indices_final,
        }

    def __getitem__(self, index):
        """Sample one training example; retries forever on bad episodes.

        NOTE(review): `index` is ignored — each call draws a random episode.
        """
        while True:
            try:
                # Pick a random episode.
                episode_dir = random.choice(self.episode_dirs)
                episode_name = os.path.basename(episode_dir)

                # Load the pre-encoded latents + camera data.
                # NOTE(review): weights_only=False unpickles arbitrary objects —
                # only safe because these files are produced by our own encoder.
                encoded_data = torch.load(
                    os.path.join(episode_dir, "encoded_video.pth"),
                    weights_only=False,
                    map_location="cpu"
                )

                full_latents = encoded_data['latents']  # [C, T, H, W]
                # Skip episodes too short to be useful (<= 10 latent frames).
                if full_latents.shape[1] <= 10:
                    continue
                cam_data = encoded_data['cam_emb']

                # FramePack-style segment selection.
                segment_info = self.select_dynamic_segment_framepack(full_latents)
                if segment_info is None:
                    continue

                # Pose embeddings for all (condition + target) frames.
                all_camera_embeddings = self.create_pose_embeddings(cam_data, segment_info)
                if all_camera_embeddings is None:
                    continue

                # Multi-scale FramePack inputs.
                framepack_inputs = self.prepare_framepack_inputs(full_latents, segment_info)

                n = segment_info["condition_frames"]
                m = segment_info['target_frames']

                # Condition mask column: 1 marks condition frames.
                mask = torch.zeros(n+m, dtype=torch.float32)
                mask[:n] = 1.0
                mask = mask.view(-1, 1)

                # Append the mask to the camera embeddings.
                camera_with_mask = torch.cat([all_camera_embeddings, mask], dim=1)

                result = {
                    # FramePack-style multi-scale inputs.
                    "latents": framepack_inputs['latents'],                  # prediction target
                    "clean_latents": framepack_inputs['clean_latents'],      # 1x condition (2 frames)
                    "clean_latents_2x": framepack_inputs['clean_latents_2x'],# 2x condition (2 slots, zero-padded)
                    "clean_latents_4x": framepack_inputs['clean_latents_4x'],# 4x condition (16 slots, zero-padded)
                    "latent_indices": framepack_inputs['latent_indices'],
                    "clean_latent_indices": framepack_inputs['clean_latent_indices'],
                    "clean_latent_2x_indices": framepack_inputs['clean_latent_2x_indices'],
                    "clean_latent_4x_indices": framepack_inputs['clean_latent_4x_indices'],

                    # Camera embeddings for all frames, with the mask column.
                    "camera": camera_with_mask,

                    "prompt_emb": encoded_data["prompt_emb"],
                    "image_emb": encoded_data.get("image_emb", {}),

                    "condition_frames": n,
                    "target_frames": m,
                    "episode_name": episode_name,
                    "dataset_name": "openx-fractal",
                    "original_condition_frames": segment_info['original_condition_frames'],
                    "original_target_frames": segment_info['original_target_frames'],
                }

                return result

            except Exception as e:
                # Broad catch is deliberate: corrupt episodes are logged and skipped.
                print(f"Error loading sample from {episode_dir}: {e}")
                import traceback
                traceback.print_exc()
                continue

    def __len__(self):
        # Virtual length: sampling is random, so the "epoch" size is a knob.
        return self.steps_per_epoch

def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' loader class for WanModelFuture.

    Mutates the global `model_loader_configs` registry in place, so it must be
    called before models are loaded.
    """
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    # Rewrite matching entries of model_loader_configs.
    for i, config in
enumerate(model_loader_configs): + keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config + + # 检查是否包含wan_video_dit模型 + if 'wan_video_dit' in model_names: + # 找到wan_video_dit的索引并替换为WanModelFuture + new_model_names = [] + new_model_classes = [] + + for name, cls in zip(model_names, model_classes): + if name == 'wan_video_dit': + new_model_names.append(name) # 保持名称不变 + new_model_classes.append(WanModelFuture) # 替换为新的类 + print(f"✅ 替换了模型类: {name} -> WanModelFuture") + else: + new_model_names.append(name) + new_model_classes.append(cls) + + # 更新配置 + model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) + +class OpenXLightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + resume_ckpt_path=None + ): + super().__init__() + replace_dit_model_in_manager() # 在这里调用 + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + # 🔧 添加FramePack的clean_x_embedder + self.add_framepack_components() + + # 添加相机编码器 + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + block.cam_encoder = nn.Linear(13, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + 
self.pipe.dit.load_state_dict(state_dict, strict=True) + print('load checkpoint:', resume_ckpt_path) + + self.freeze_parameters() + + # 只训练相机相关和注意力模块以及FramePack相关组件 + for name, module in self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn", "clean_x_embedder"]): + for param in module.parameters(): + param.requires_grad = True + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + # 创建可视化目录 + self.vis_dir = "openx_training/visualizations" + os.makedirs(self.vis_dir, exist_ok=True) + + def add_framepack_components(self): + """🔧 添加FramePack相关组件""" + if not hasattr(self.pipe.dit, 'clean_x_embedder'): + inner_dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + # 参考hunyuan_video_packed.py的设计,但适配OpenX数据的分辨率 + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + if scale == "1x": + return self.proj(x) + elif scale == "2x": + return self.proj_2x(x) + elif scale == "4x": + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + self.pipe.dit.clean_x_embedder = CleanXEmbedder(inner_dim) + print("✅ 添加了FramePack的clean_x_embedder组件") + + def freeze_parameters(self): + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + def training_step(self, batch, batch_idx): + """🔧 使用FramePack风格的训练步骤 - 适配OpenX数据""" + condition_frames = batch["condition_frames"][0].item() + target_frames = batch["target_frames"][0].item() + + original_condition_frames = batch.get("original_condition_frames", 
[condition_frames * 4])[0] + original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0] + + dataset_name = batch.get("dataset_name", ["unknown"])[0] + episode_name = batch.get("episode_name", ["unknown"])[0] + + # 🔧 准备FramePack风格的输入 - 确保有batch维度 + latents = batch["latents"].to(self.device) + if len(latents.shape) == 4: # [C, T, H, W] + latents = latents.unsqueeze(0) # -> [1, C, T, H, W] + + # 🔧 条件输入(处理空张量和维度) + clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None + if clean_latents is not None and len(clean_latents.shape) == 4: + clean_latents = clean_latents.unsqueeze(0) + + clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None + if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4: + clean_latents_2x = clean_latents_2x.unsqueeze(0) + + clean_latents_4x = batch["clean_latents_4x"].to(self.device) if batch["clean_latents_4x"].numel() > 0 else None + if clean_latents_4x is not None and len(clean_latents_4x.shape) == 4: + clean_latents_4x = clean_latents_4x.unsqueeze(0) + + # 🔧 索引(处理空张量) + latent_indices = batch["latent_indices"].to(self.device) + clean_latent_indices = batch["clean_latent_indices"].to(self.device) if batch["clean_latent_indices"].numel() > 0 else None + clean_latent_2x_indices = batch["clean_latent_2x_indices"].to(self.device) if batch["clean_latent_2x_indices"].numel() > 0 else None + clean_latent_4x_indices = batch["clean_latent_4x_indices"].to(self.device) if batch["clean_latent_4x_indices"].numel() > 0 else None + + # 🔧 直接使用带mask的camera embeddings + cam_emb = batch["camera"].to(self.device) + camera_dropout_prob = 0.1 # 10%概率丢弃camera条件 + if random.random() < camera_dropout_prob: + # 创建零camera embedding + cam_emb = torch.zeros_like(cam_emb) + print("应用camera dropout for CFG training") + + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb 
= batch["image_emb"] + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + # Loss计算 + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + # 🔧 FramePack风格的噪声处理 + noisy_condition_latents = None + if clean_latents is not None: + noisy_condition_latents = copy.deepcopy(clean_latents) + is_add_noise = random.random() + if is_add_noise > 0.2: # 80%概率添加噪声 + noise_cond = torch.randn_like(clean_latents) + timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps//4*3, (1,)) + timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + noisy_condition_latents = self.pipe.scheduler.add_noise(clean_latents, noise_cond, timestep_cond) + + extra_input = self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # 🔧 使用FramePack风格的forward调用 + noise_pred = self.pipe.denoising_model()( + noisy_latents, + timestep=timestep, + cam_emb=cam_emb, # 🔧 直接传递带mask的camera embeddings + # 🔧 FramePack风格的条件输入 + latent_indices=latent_indices, + clean_latents=noisy_condition_latents if noisy_condition_latents is not None else clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb, + **extra_input, + **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + 
def train_openx(args):
    """Wire up the OpenX FramePack dataset, Lightning module and trainer, then train."""
    dataset = OpenXFramePackDataset(
        args.dataset_path,
        steps_per_epoch=args.steps_per_epoch,
        min_condition_frames=args.min_condition_frames,
        max_condition_frames=args.max_condition_frames,
        target_frames=args.target_frames,
    )

    # batch_size stays at 1: samples carry variable-length temporal conditioning.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=args.dataloader_num_workers,
    )

    lit_model = OpenXLightningModelForTrain(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        resume_ckpt_path=args.resume_ckpt_path,
    )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[],
        logger=False,
    )
    trainer.fit(lit_model, loader)
class TextVideoDataset(torch.utils.data.Dataset):
    """(text, video) pairs listed in a metadata CSV.

    Videos are scaled up to cover (height, width), center-cropped, and returned
    as a C x T x H x W tensor normalized to [-1, 1]. In i2v mode the unprocessed
    first frame is returned alongside the clip.
    """

    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1,
                 num_frames=81, height=480, width=832, is_i2v=False):
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.is_i2v = is_i2v

        self.frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Upscale the PIL image so both target dimensions are covered (cropping happens later)."""
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Decode `num_frames` frames at the given stride; return None if the video is too short."""
        reader = imageio.get_reader(file_path)
        try:
            if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
                return None

            frames = []
            first_frame = None
            for frame_id in range(num_frames):
                frame = reader.get_data(start_frame_id + frame_id * interval)
                frame = Image.fromarray(frame)
                frame = self.crop_and_resize(frame)
                if first_frame is None:
                    first_frame = np.array(frame)
                frame = frame_process(frame)
                frames.append(frame)
        finally:
            # FIX: close on all paths — the original leaked the reader when a
            # frame failed to decode and an exception escaped.
            reader.close()

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        if self.is_i2v:
            return frames, first_frame
        return frames

    def load_video(self, file_path):
        start_frame_id = 0
        return self.load_frames_using_imageio(
            file_path, self.max_num_frames, start_frame_id,
            self.frame_interval, self.num_frames, self.frame_process)

    def is_image(self, file_path):
        """Return True when the file extension looks like a still image."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]

    def load_image(self, file_path):
        frame = Image.open(file_path).convert("RGB")
        frame = self.crop_and_resize(frame)
        first_frame = frame
        frame = self.frame_process(frame)
        frame = rearrange(frame, "C H W -> C 1 H W")
        return frame

    def __getitem__(self, data_id):
        # BUGFIX: the original resolved text/path once *before* its retry loop and
        # caught with a bare `except:`, so a single unreadable sample retried the
        # same item forever. Re-resolve the sample on every attempt, wrap the
        # index, and bound the number of retries.
        for _ in range(len(self.path)):
            idx = data_id % len(self.path)
            text = self.text[idx]
            path = self.path[idx]
            try:
                if self.is_image(path):
                    if self.is_i2v:
                        raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
                    video = self.load_image(path)
                else:
                    video = self.load_video(path)
                if self.is_i2v:
                    video, first_frame = video
                    return {"text": text, "video": video, "path": path, "first_frame": first_frame}
                return {"text": text, "video": video, "path": path}
            except Exception:
                data_id += 1  # move on to the next sample
        raise RuntimeError("No loadable sample found in dataset.")

    def __len__(self):
        return len(self.path)
I2V model doesn't support image-to-image training.") + video = self.load_image(path) + else: + video = self.load_video(path) + if self.is_i2v: + video, first_frame = video + data = {"text": text, "video": video, "path": path, "first_frame": first_frame} + else: + data = {"text": text, "video": video, "path": path} + break + except: + data_id += 1 + return data + + + def __len__(self): + return len(self.path) + + + +class LightningModelForDataProcess(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_path = [text_encoder_path, vae_path] + if image_encoder_path is not None: + model_path.append(image_encoder_path) + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models(model_path) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + def test_step(self, batch, batch_idx): + text, video, path = batch["text"][0], batch["video"], batch["path"][0] + + self.pipe.device = self.device + if video is not None: + pth_path = path + ".tensors.pth" + if not os.path.exists(pth_path): + # prompt + prompt_emb = self.pipe.encode_prompt(text) + # video + video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] + # image + if "first_frame" in batch: + first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy()) + _, _, num_frames, height, width = video.shape + image_emb = self.pipe.encode_image(first_frame, num_frames, height, width) + else: + image_emb = {} + data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb} + torch.save(data, pth_path) + else: + print(f"File {pth_path} already exists, skipping.") + +class Camera(object): + def __init__(self, c2w): + c2w_mat = 
class TensorDataset(torch.utils.data.Dataset):
    """Serves cached (latent, camera, prompt) samples produced by the data_process stage."""

    def __init__(self, base_path, metadata_path, steps_per_epoch):
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        print(len(self.path), "videos in metadata.")
        self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
        print(len(self.path), "tensors cached in metadata.")
        assert len(self.path) > 0
        self.steps_per_epoch = steps_per_epoch

    def parse_matrix(self, matrix_str):
        """Parse a '[a b] [c d]'-style string into a numpy array (one row per bracket group)."""
        return np.array([
            list(map(float, row.replace('[', '').replace(']', '').split()))
            for row in matrix_str.strip().split('] [')
        ])

    def get_relative_pose(self, cam_params):
        """Express every pose relative to the first camera; the first pose maps to identity."""
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        return np.array(ret_poses, dtype=np.float32)

    def __getitem__(self, index):
        # Returns:
        #   data['latents']: e.g. torch.Size([16, 21*2, 60, 104]) — target + condition
        #                    latents concatenated along the frame axis
        #   data['camera']:  per-frame relative pose embedding (bfloat16)
        #   data['prompt_emb']["context"][0]: e.g. torch.Size([512, 4096])
        while True:
            try:
                data = {}
                data_id = torch.randint(0, len(self.path), (1,))[0]
                data_id = (data_id + index) % len(self.path)  # For fixed seed.
                path_tgt = self.path[data_id]
                data_tgt = torch.load(path_tgt, weights_only=True, map_location="cpu")

                # Pick a different camera of the same scene as the condition view.
                match = re.search(r'cam(\d+)', path_tgt)
                tgt_idx = int(match.group(1))
                cond_idx = random.randint(1, 10)
                while cond_idx == tgt_idx:
                    cond_idx = random.randint(1, 10)
                path_cond = re.sub(r'cam(\d+)', f'cam{cond_idx:02}', path_tgt)
                data_cond = torch.load(path_cond, weights_only=True, map_location="cpu")
                data['latents'] = torch.cat((data_tgt['latents'], data_cond['latents']), dim=1)
                data['prompt_emb'] = data_tgt['prompt_emb']
                data['image_emb'] = {}

                # Load the target trajectory and convert it to relative pose embeddings.
                base_path = path_tgt.rsplit('/', 2)[0]
                tgt_camera_path = os.path.join(base_path, "cameras", "camera_extrinsics.json")
                with open(tgt_camera_path, 'r') as file:
                    cam_data = json.load(file)
                multiview_c2ws = []
                cam_idx = list(range(81))[::4]  # every 4th of 81 frames -> 21 poses
                for view_idx in [cond_idx, tgt_idx]:
                    traj = [self.parse_matrix(cam_data[f"frame{idx}"][f"cam{view_idx:02d}"]) for idx in cam_idx]
                    traj = np.stack(traj).transpose(0, 2, 1)
                    c2ws = []
                    for c2w in traj:
                        # Axis permutation + sign flip into the pipeline's convention;
                        # translation is rescaled from centimeters.
                        c2w = c2w[:, [1, 2, 0, 3]]
                        c2w[:3, 1] *= -1.
                        c2w[:3, 3] /= 100
                        c2ws.append(c2w)
                    multiview_c2ws.append(c2ws)
                cond_cam_params = [Camera(cam_param) for cam_param in multiview_c2ws[0]]
                tgt_cam_params = [Camera(cam_param) for cam_param in multiview_c2ws[1]]
                relative_poses = []
                for i in range(len(tgt_cam_params)):
                    relative_pose = self.get_relative_pose([cond_cam_params[0], tgt_cam_params[i]])
                    relative_poses.append(torch.as_tensor(relative_pose)[:, :3, :][1])
                pose_embedding = torch.stack(relative_poses, dim=0)  # 21 x 3 x 4
                pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
                data['camera'] = pose_embedding.to(torch.bfloat16)
                break
            except Exception as e:
                # Best-effort retry with a fresh random index; keep the loader alive.
                print(f"ERROR WHEN LOADING: {e}")
                index = random.randrange(len(self.path))
        return data

    def __len__(self):
        return self.steps_per_epoch
class LightningModelForTrain(pl.LightningModule):
    """Fine-tunes the ReCamMaster DiT.

    Injects per-block camera-conditioning layers (`cam_encoder`, `projector`)
    into every DiT block and trains them together with the self-attention
    layers; everything else stays frozen.
    """

    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
        resume_ckpt_path=None
    ):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        if os.path.isfile(dit_path):
            model_manager.load_models([dit_path])
        else:
            # NOTE(review): a comma-separated dit_path is split into a list that is
            # passed as a SINGLE entry — presumably so ModelManager loads the parts
            # as one sharded checkpoint; confirm against ModelManager.load_models.
            dit_path = dit_path.split(",")
            model_manager.load_models([dit_path])

        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)

        # Add camera conditioning to every DiT block. cam_encoder starts at zero and
        # projector at identity, so training starts from the unmodified model output.
        dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
        for block in self.pipe.dit.blocks:
            block.cam_encoder = nn.Linear(12, dim)
            block.projector = nn.Linear(dim, dim)
            block.cam_encoder.weight.data.zero_()
            block.cam_encoder.bias.data.zero_()
            block.projector.weight = nn.Parameter(torch.eye(dim))
            block.projector.bias = nn.Parameter(torch.zeros(dim))

        if resume_ckpt_path is not None:
            state_dict = torch.load(resume_ckpt_path, map_location="cpu")
            self.pipe.dit.load_state_dict(state_dict, strict=True)

        # Freeze everything, then selectively re-enable the camera layers and
        # self-attention modules.
        self.freeze_parameters()
        for name, module in self.pipe.denoising_model().named_modules():
            if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn"]):
                print(f"Trainable: {name}")
                for param in module.parameters():
                    param.requires_grad = True

        # Count unique trainable parameters. named_modules() yields parents and
        # children, so the same parameter is visited repeatedly — dedupe via set.
        trainable_params = 0
        seen_params = set()
        for name, module in self.pipe.denoising_model().named_modules():
            for param in module.parameters():
                if param.requires_grad and param not in seen_params:
                    trainable_params += param.numel()
                    seen_params.add(param)
        print(f"Total number of trainable parameters: {trainable_params}")

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

    def freeze_parameters(self):
        # Freeze parameters. The denoising model is kept in train() mode so the
        # submodules re-enabled afterwards behave as during training.
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()

    def training_step(self, batch, batch_idx):
        # Data. `latents` holds target and condition latents concatenated along
        # the frame axis (dim 2): first half target, second half condition.
        latents = batch["latents"].to(self.device)
        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]

        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        cam_emb = batch["camera"].to(self.device)

        # Loss
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        origin_latents = copy.deepcopy(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        # Keep the condition half clean: restore the second half of the frame axis
        # so only the target half is actually noised (and supervised below).
        tgt_latent_len = noisy_latents.shape[2] // 2
        noisy_latents[:, :, tgt_latent_len:, ...] = origin_latents[:, :, tgt_latent_len:, ...]
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Compute loss on the target half only.
        noise_pred = self.pipe.denoising_model()(
            noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )
        loss = torch.nn.functional.mse_loss(noise_pred[:, :, :tgt_latent_len, ...].float(), training_target[:, :, :tgt_latent_len, ...].float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Optimize only the parameters left trainable in __init__.
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        # Replace Lightning's checkpoint payload with a plain state_dict file
        # saved under the checkpoint callback's directory.
        checkpoint_dir = self.trainer.checkpoint_callback.dirpath
        print(f"Checkpoint directory: {checkpoint_dir}")
        current_step = self.global_step
        print(f"Current step: {current_step}")

        checkpoint.clear()
        # NOTE(review): trainable_param_names is computed but never used — the full
        # state_dict is saved regardless.
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        state_dict = self.pipe.denoising_model().state_dict()
        torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}.ckpt"))
def parse_args():
    """Parse command-line options shared by the data_process and train tasks."""
    p = argparse.ArgumentParser(description="Train ReCamMaster")

    # Task selection & model/data paths
    p.add_argument("--task", type=str, default="data_process", required=True,
                   choices=["data_process", "train"], help="Task. `data_process` or `train`.")
    p.add_argument("--dataset_path", type=str, default=None, required=True,
                   help="The path of the Dataset.")
    p.add_argument("--output_path", type=str, default="./", help="Path to save the model.")
    p.add_argument("--text_encoder_path", type=str, default=None, help="Path of text encoder.")
    p.add_argument("--image_encoder_path", type=str, default=None, help="Path of image encoder.")
    p.add_argument("--vae_path", type=str, default=None, help="Path of VAE.")
    p.add_argument("--dit_path", type=str, default=None, help="Path of DiT.")

    # VAE tiling (memory saving during encode)
    p.add_argument("--tiled", default=False, action="store_true",
                   help="Whether enable tile encode in VAE. This option can reduce VRAM required.")
    p.add_argument("--tile_size_height", type=int, default=34, help="Tile size (height) in VAE.")
    p.add_argument("--tile_size_width", type=int, default=34, help="Tile size (width) in VAE.")
    p.add_argument("--tile_stride_height", type=int, default=18, help="Tile stride (height) in VAE.")
    p.add_argument("--tile_stride_width", type=int, default=16, help="Tile stride (width) in VAE.")

    # Data shape & loading
    p.add_argument("--steps_per_epoch", type=int, default=500, help="Number of steps per epoch.")
    p.add_argument("--num_frames", type=int, default=81, help="Number of frames.")
    p.add_argument("--height", type=int, default=480, help="Image height.")
    p.add_argument("--width", type=int, default=832, help="Image width.")
    p.add_argument("--dataloader_num_workers", type=int, default=4,
                   help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.")

    # Optimization
    p.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate.")
    p.add_argument("--accumulate_grad_batches", type=int, default=1,
                   help="The number of batches in gradient accumulation.")
    p.add_argument("--max_epochs", type=int, default=1, help="Number of epochs.")
    p.add_argument("--training_strategy", type=str, default="deepspeed_stage_1",
                   choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
                   help="Training strategy")
    p.add_argument("--use_gradient_checkpointing", default=False, action="store_true",
                   help="Whether to use gradient checkpointing.")
    p.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true",
                   help="Whether to use gradient checkpointing offload.")

    # Logging & resume
    p.add_argument("--use_swanlab", default=False, action="store_true",
                   help="Whether to use SwanLab logger.")
    p.add_argument("--swanlab_mode", default=None, help="SwanLab mode (cloud or local).")
    p.add_argument("--metadata_file_name", type=str, default="metadata.csv")
    p.add_argument("--resume_ckpt_path", type=str, default=None)

    return p.parse_args()
tile_size=(args.tile_size_height, args.tile_size_width), + tile_stride=(args.tile_stride_height, args.tile_stride_width), + ) + trainer = pl.Trainer( + accelerator="gpu", + devices="auto", + default_root_dir=args.output_path, + ) + trainer.test(model, dataloader) + + +def train(args): + dataset = TensorDataset( + args.dataset_path, + os.path.join(args.dataset_path, "metadata.csv"), + steps_per_epoch=args.steps_per_epoch, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + ) + + if args.use_swanlab: + from swanlab.integration.pytorch_lightning import SwanLabLogger + swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} + swanlab_config.update(vars(args)) + swanlab_logger = SwanLabLogger( + project="wan", + name="wan", + config=swanlab_config, + mode=args.swanlab_mode, + logdir=os.path.join(args.output_path, "swanlog"), + ) + logger = [swanlab_logger] + else: + logger = None + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[], + logger=logger, + ) + trainer.fit(model, dataloader) + + +if __name__ == '__main__': + args = parse_args() + os.makedirs(os.path.join(args.output_path, "checkpoints"), exist_ok=True) + if args.task == "data_process": + data_process(args) + elif args.task == "train": + train(args) \ No newline at end of file diff --git a/scripts/train_origin.py b/scripts/train_origin.py new file mode 100644 index 0000000000000000000000000000000000000000..e66ef10acbad63b1b7ea8f729a4acdaf8e23f437 --- 
/dev/null +++ b/scripts/train_origin.py @@ -0,0 +1,1263 @@ +#融合nuscenes和sekai数据集的MoE训练 +import torch +import torch.nn as nn +import lightning as pl +import wandb +import os +import copy +import json +import numpy as np +import random +import traceback +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +from torchvision.transforms import v2 +from einops import rearrange +from pose_classifier import PoseClassifier +import argparse +from scipy.spatial.transform import Rotation as R + +def get_traj_position_change(cam_c2w, stride=1): + positions = cam_c2w[:, :3, 3] + + traj_coord = [] + tarj_angle = [] + for i in range(0, len(positions) - 2 * stride): + v1 = positions[i + stride] - positions[i] + v2 = positions[i + 2 * stride] - positions[i + stride] + + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + if norm1 < 1e-6 or norm2 < 1e-6: + continue + + cos_angle = np.dot(v1, v2) / (norm1 * norm2) + angle = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0))) + + traj_coord.append(v1) + tarj_angle.append(angle) + + return traj_coord, tarj_angle + +def get_traj_rotation_change(cam_c2w, stride=1): + rotations = cam_c2w[:, :3, :3] + + traj_rot_angle = [] + for i in range(0, len(rotations) - stride): + z1 = rotations[i][:, 2] + z2 = rotations[i + stride][:, 2] + + norm1 = np.linalg.norm(z1) + norm2 = np.linalg.norm(z2) + if norm1 < 1e-6 or norm2 < 1e-6: + continue + + cos_angle = np.dot(z1, z2) / (norm1 * norm2) + angle = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0))) + traj_rot_angle.append(angle) + + return traj_rot_angle + +def compute_relative_pose(pose_a, pose_b, use_torch=False): + """计算相机B相对于相机A的相对位姿矩阵""" + assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}" + assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}" + + if use_torch: + if not isinstance(pose_a, torch.Tensor): + pose_a = torch.from_numpy(pose_a).float() + if not isinstance(pose_b, torch.Tensor): + pose_b = torch.from_numpy(pose_b).float() + + 
pose_a_inv = torch.inverse(pose_a) + relative_pose = torch.matmul(pose_b, pose_a_inv) + else: + if not isinstance(pose_a, np.ndarray): + pose_a = np.array(pose_a, dtype=np.float32) + if not isinstance(pose_b, np.ndarray): + pose_b = np.array(pose_b, dtype=np.float32) + + pose_a_inv = np.linalg.inv(pose_a) + relative_pose = np.matmul(pose_b, pose_a_inv) + + return relative_pose + +def compute_relative_pose_matrix(pose1, pose2): + """ + 计算相邻两帧的相对位姿,返回3×4的相机矩阵 [R_rel | t_rel] + + 参数: + pose1: 第i帧的相机位姿,形状为(7,)的数组 [tx1, ty1, tz1, qx1, qy1, qz1, qw1] + pose2: 第i+1帧的相机位姿,形状为(7,)的数组 [tx2, ty2, tz2, qx2, qy2, qz2, qw2] + + 返回: + relative_matrix: 3×4的相对位姿矩阵,前3列是旋转矩阵R_rel,第4列是平移向量t_rel + """ + # 分离平移向量和四元数 + t1 = pose1[:3] # 第i帧平移 [tx1, ty1, tz1] + q1 = pose1[3:] # 第i帧四元数 [qx1, qy1, qz1, qw1] + t2 = pose2[:3] # 第i+1帧平移 + q2 = pose2[3:] # 第i+1帧四元数 + + # 1. 计算相对旋转矩阵 R_rel + rot1 = R.from_quat(q1) # 第i帧旋转 + rot2 = R.from_quat(q2) # 第i+1帧旋转 + rot_rel = rot2 * rot1.inv() # 相对旋转 = 后一帧旋转 × 前一帧旋转的逆 + R_rel = rot_rel.as_matrix() # 转换为3×3矩阵 + + # 2. 计算相对平移向量 t_rel + R1_T = rot1.as_matrix().T # 前一帧旋转矩阵的转置(等价于逆) + t_rel = R1_T @ (t2 - t1) # 相对平移 = R1^T × (t2 - t1) + + # 3. 
    def __init__(self, dataset_configs, steps_per_epoch,
                 min_condition_frames=10, max_condition_frames=40,
                 target_frames=10, height=900, width=1600):
        """Build one flat scene index over several encoded datasets.

        Args:
            dataset_configs: list of dataset configs, each a dict {
                'name': dataset name,
                'paths': a path or a list of paths,
                'type': dataset type ('nuscenes', 'sekai', 'spatialvid' or 'openx'),
                'weight': sampling weight (defaults to 1.0)
            }
            steps_per_epoch: virtual epoch length.
            min_condition_frames / max_condition_frames: history-length range in raw frames.
            target_frames: number of predicted frames (raw frames).
            height / width: frame resolution.
        """
        self.dataset_configs = dataset_configs
        self.min_condition_frames = min_condition_frames
        self.max_condition_frames = max_condition_frames
        self.target_frames = target_frames
        self.height = height
        self.width = width
        self.steps_per_epoch = steps_per_epoch
        self.pose_classifier = PoseClassifier()

        # VAE temporal compression ratio: 4 raw frames per latent frame.
        self.time_compression_ratio = 4

        # Scan every configured dataset and build a unified scene index.
        self.scene_dirs = []
        self.dataset_info = {}      # scene_dir -> {'name', 'type', 'weight'}
        self.dataset_weights = []   # per-scene sampling weight, parallel to scene_dirs

        total_scenes = 0

        for config in self.dataset_configs:
            dataset_name = config['name']
            dataset_paths = config['paths'] if isinstance(config['paths'], list) else [config['paths']]
            dataset_type = config['type']
            dataset_weight = config.get('weight', 1.0)

            print(f"🔧 扫描数据集: {dataset_name} (类型: {dataset_type})")

            dataset_scenes = []
            for dataset_path in dataset_paths:
                print(f" 📁 检查路径: {dataset_path}")
                if os.path.exists(dataset_path):
                    if dataset_type == 'nuscenes':
                        # NuScenes layout: <base_path>/scenes/<scene_dir>
                        scenes_path = os.path.join(dataset_path, "scenes")
                        print(f" 📂 扫描NuScenes scenes目录: {scenes_path}")
                        for item in os.listdir(scenes_path):
                            scene_dir = os.path.join(scenes_path, item)
                            if os.path.isdir(scene_dir):
                                self.scene_dirs.append(scene_dir)
                                dataset_scenes.append(scene_dir)
                                self.dataset_info[scene_dir] = {
                                    'name': dataset_name,
                                    'type': dataset_type,
                                    'weight': dataset_weight
                                }
                                self.dataset_weights.append(dataset_weight)
                    elif dataset_type in ['sekai', 'spatialvid', 'openx']:
                        # Sekai / spatialvid / OpenX: scene dirs sit directly under the
                        # root and must contain a cached "encoded_video.pth".
                        for item in os.listdir(dataset_path):
                            scene_dir = os.path.join(dataset_path, item)
                            if os.path.isdir(scene_dir):
                                encoded_path = os.path.join(scene_dir, "encoded_video.pth")
                                if os.path.exists(encoded_path):
                                    self.scene_dirs.append(scene_dir)
                                    dataset_scenes.append(scene_dir)
                                    self.dataset_info[scene_dir] = {
                                        'name': dataset_name,
                                        'type': dataset_type,
                                        'weight': dataset_weight
                                    }
                                    self.dataset_weights.append(dataset_weight)
                else:
                    print(f" ❌ 路径不存在: {dataset_path}")

            print(f" ✅ 找到 {len(dataset_scenes)} 个场景")
            total_scenes += len(dataset_scenes)

        # Per-dataset scene counts (reporting only).
        dataset_counts = {}
        for scene_dir in self.scene_dirs:
            dataset_name = self.dataset_info[scene_dir]['name']
            dataset_type = self.dataset_info[scene_dir]['type']
            key = f"{dataset_name} ({dataset_type})"
            dataset_counts[key] = dataset_counts.get(key, 0) + 1

        for dataset_key, count in dataset_counts.items():
            print(f" - {dataset_key}: {count} 个场景")

        assert len(self.scene_dirs) > 0, "No encoded scenes found!"

        # Normalize weights into per-scene sampling probabilities.
        total_weight = sum(self.dataset_weights)
        self.sampling_probs = [w / total_weight for w in self.dataset_weights]
+ + # 🔧 计算采样概率 + total_weight = sum(self.dataset_weights) + self.sampling_probs = [w / total_weight for w in self.dataset_weights] + + def select_dynamic_segment_nuscenes(self, scene_info): + """🔧 NuScenes专用的FramePack风格段落选择""" + keyframe_indices = scene_info['keyframe_indices'] # 原始帧索引 + total_frames = scene_info['total_frames'] # 原始总帧数 + + if len(keyframe_indices) < 2: + return None + + # 计算压缩后的帧数 + compressed_total_frames = total_frames // self.time_compression_ratio + compressed_keyframe_indices = [idx // self.time_compression_ratio for idx in keyframe_indices] + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + + # FramePack风格的采样策略 + ratio = random.random() + if ratio < 0.15: + condition_frames_compressed = 1 + elif 0.15 <= ratio < 0.9: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if compressed_total_frames < min_required_frames: + return None + + start_frame_compressed = random.randint(0, compressed_total_frames - min_required_frames - 1) + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + # FramePack风格的索引处理 + latent_indices = torch.arange(condition_end_compressed, target_end_compressed) + + # 1x帧:起始帧 + 最后1帧 + clean_latent_indices_start = torch.tensor([start_frame_compressed]) + clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1]) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices]) + + # 🔧 2x帧:根据实际condition长度确定 + if condition_frames_compressed >= 2: + # 取最后2帧(如果有的话) + 
clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2) + clean_latent_2x_indices = torch.arange(clean_latent_2x_start-1, condition_end_compressed-1) + else: + # 如果condition帧数不足2帧,创建空索引 + clean_latent_2x_indices = torch.tensor([], dtype=torch.long) + + # 🔧 4x帧:根据实际condition长度确定,最多16帧 + if condition_frames_compressed >= 1: + # 取最多16帧的历史(如果有的话) + clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16) + clean_latent_4x_indices = torch.arange(clean_4x_start-3, condition_end_compressed-3) + else: + clean_latent_4x_indices = torch.tensor([], dtype=torch.long) + + # 🔧 NuScenes特有:查找关键帧索引 + condition_keyframes_compressed = [idx for idx in compressed_keyframe_indices + if start_frame_compressed <= idx < condition_end_compressed] + + target_keyframes_compressed = [idx for idx in compressed_keyframe_indices + if condition_end_compressed <= idx < target_end_compressed] + + if not condition_keyframes_compressed: + return None + + # 使用条件段的最后一个关键帧作为reference + reference_keyframe_compressed = max(condition_keyframes_compressed) + + # 找到对应的原始关键帧索引用于pose查找 + reference_keyframe_original_idx = None + for i, compressed_idx in enumerate(compressed_keyframe_indices): + if compressed_idx == reference_keyframe_compressed: + reference_keyframe_original_idx = i + break + + if reference_keyframe_original_idx is None: + return None + + # 找到目标段对应的原始关键帧索引 + target_keyframes_original_indices = [] + for compressed_idx in target_keyframes_compressed: + for i, comp_idx in enumerate(compressed_keyframe_indices): + if comp_idx == compressed_idx: + target_keyframes_original_indices.append(i) + break + + # 对应的原始关键帧索引 + keyframe_original_idx = [] + for compressed_idx in range(start_frame_compressed, target_end_compressed): + keyframe_original_idx.append(compressed_idx * 4) + + return { + 'start_frame': start_frame_compressed, + 'condition_frames': condition_frames_compressed, + 'target_frames': target_frames_compressed, + 'condition_range': 
(start_frame_compressed, condition_end_compressed),
                'target_range': (condition_end_compressed, target_end_compressed),

                # FramePack-style indices
                'latent_indices': latent_indices,
                'clean_latent_indices': clean_latent_indices,
                'clean_latent_2x_indices': clean_latent_2x_indices,
                'clean_latent_4x_indices': clean_latent_4x_indices,

                'keyframe_original_idx': keyframe_original_idx,
                'original_condition_frames': condition_frames_compressed * self.time_compression_ratio,
                'original_target_frames': target_frames_compressed * self.time_compression_ratio,

                # NuScenes-specific data
                'reference_keyframe_idx': reference_keyframe_original_idx,
                'target_keyframe_indices': target_keyframes_original_indices,
            }

    def calculate_relative_rotation(self, current_rotation, reference_rotation):
        """Compute the rotation of `current` relative to `reference` (NuScenes-specific).

        Both arguments are quaternions in [w, x, y, z] order (the code negates
        components 1..3 to conjugate, keeping index 0 as the scalar part).
        Returns q_ref^-1 * q_current as a float32 tensor of shape [4].

        NOTE(review): the conjugate is used as the inverse, which is only
        correct for unit quaternions — TODO confirm NuScenes ego poses are
        normalized.
        """
        q_current = torch.tensor(current_rotation, dtype=torch.float32)
        q_ref = torch.tensor(reference_rotation, dtype=torch.float32)

        # Conjugate of the reference quaternion (= inverse for unit quaternions).
        q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]])

        w1, x1, y1, z1 = q_ref_inv
        w2, x2, y2, z2 = q_current

        # Hamilton product q_ref_inv ⊗ q_current, written out component-wise.
        relative_rotation = torch.tensor([
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
        ])

        return relative_rotation


    def prepare_framepack_inputs(self, full_latents, segment_info):
        """Prepare FramePack-style multi-scale inputs; correctly handles empty index sets."""
        # Accept a 4-D input [C, T, H, W] by adding a batch dimension.
        if len(full_latents.shape) == 4:
            full_latents = full_latents.unsqueeze(0)  # [C, T, H, W] -> [1, C, T, H, W]
            B, C, T, H, W = full_latents.shape
        else:
            B, C, T, H, W = full_latents.shape

        # Main latents (the frames to be denoised), gathered along the time axis.
        latent_indices = segment_info['latent_indices']
        main_latents = full_latents[:, :, latent_indices, :, :]  # note the dim order: dim 2 is time

        # 1x condition frames (start frame + last condition frame).
        clean_latent_indices = segment_info['clean_latent_indices']
        clean_latents = full_latents[:, :, clean_latent_indices, :, :]  # note the dim order: dim 2 is time

        # 🔧 4x条件帧 - 
总是16帧,直接用真实索引 + 0填充 + clean_latent_4x_indices = segment_info['clean_latent_4x_indices'] + + # 创建固定长度16的latents,初始化为0 + clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype) + clean_latent_4x_indices_final = torch.full((16,), -1, dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的4x索引 + if len(clean_latent_4x_indices) > 0: + actual_4x_frames = len(clean_latent_4x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 16 - actual_4x_frames) + end_pos = 16 + actual_start = max(0, actual_4x_frames - 16) # 如果超过16帧,只取最后16帧 + + clean_latents_4x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :] + clean_latent_4x_indices_final[start_pos:end_pos] = clean_latent_4x_indices[actual_start:] + + # 🔧 2x条件帧 - 总是2帧,直接用真实索引 + 0填充 + clean_latent_2x_indices = segment_info['clean_latent_2x_indices'] + + # 创建固定长度2的latents,初始化为0 + clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype) + clean_latent_2x_indices_final = torch.full((2,), -1, dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的2x索引 + if len(clean_latent_2x_indices) > 0: + actual_2x_frames = len(clean_latent_2x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 2 - actual_2x_frames) + end_pos = 2 + actual_start = max(0, actual_2x_frames - 2) # 如果超过2帧,只取最后2帧 + + clean_latents_2x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_2x_indices[actual_start:], :, :] + clean_latent_2x_indices_final[start_pos:end_pos] = clean_latent_2x_indices[actual_start:] + + # 🔧 移除添加的batch维度,返回原始格式 + if B == 1: + main_latents = main_latents.squeeze(0) # [1, C, T, H, W] -> [C, T, H, W] + clean_latents = clean_latents.squeeze(0) + clean_latents_2x = clean_latents_2x.squeeze(0) + clean_latents_4x = clean_latents_4x.squeeze(0) + + return { + 'latents': main_latents, + 'clean_latents': clean_latents, + 'clean_latents_2x': clean_latents_2x, + 'clean_latents_4x': clean_latents_4x, + 'latent_indices': segment_info['latent_indices'], + 'clean_latent_indices': 
segment_info['clean_latent_indices'], + 'clean_latent_2x_indices': clean_latent_2x_indices_final, # 🔧 使用真实索引(含-1填充) + 'clean_latent_4x_indices': clean_latent_4x_indices_final, # 🔧 使用真实索引(含-1填充) + } + + def create_sekai_pose_embeddings(self, cam_data, segment_info): + """创建Sekai风格的pose embeddings""" + cam_data_seq = cam_data['extrinsic'] + + # 为所有帧计算相对pose + all_keyframe_indices = [] + for compressed_idx in range(segment_info['start_frame'], segment_info['target_range'][1]): + all_keyframe_indices.append(compressed_idx * 4) + + relative_cams = [] + for idx in all_keyframe_indices: + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx + 4] + relative_cam = compute_relative_pose(cam_prev, cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def create_openx_pose_embeddings(self, cam_data, segment_info): + """🔧 创建OpenX风格的pose embeddings - 类似sekai但处理更短的序列""" + cam_data_seq = cam_data['extrinsic'] + + # 为所有帧计算相对pose - OpenX使用4倍间隔 + all_keyframe_indices = [] + for compressed_idx in range(segment_info['start_frame'], segment_info['target_range'][1]): + keyframe_idx = compressed_idx * 4 + if keyframe_idx + 4 < len(cam_data_seq): + all_keyframe_indices.append(keyframe_idx) + + relative_cams = [] + for idx in all_keyframe_indices: + if idx + 4 < len(cam_data_seq): + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx + 4] + relative_cam = compute_relative_pose(cam_prev, cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + else: + # 如果没有下一帧,使用单位矩阵 + identity_cam = torch.eye(3, 4) + relative_cams.append(identity_cam) + + if len(relative_cams) == 0: + return None + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = 
pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def create_spatialvid_pose_embeddings(self, cam_data, segment_info): + """🔧 创建SpatialVid风格的pose embeddings - camera间隔为1帧而非4帧""" + cam_data_seq = cam_data['extrinsic'] # N * 4 * 4 + + # 🔧 为所有帧(condition + target)计算camera embedding + # SpatialVid特有:每隔1帧而不是4帧 + keyframe_original_idx = segment_info['keyframe_original_idx'] + + relative_cams = [] + for idx in keyframe_original_idx: + if idx + 1 < len(cam_data_seq): + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx + 1] # SpatialVid: 每隔1帧 + relative_cam = compute_relative_pose_matrix(cam_prev, cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + else: + # 如果没有下一帧,使用零运动 + identity_cam = torch.zeros(3, 4) + relative_cams.append(identity_cam) + + if len(relative_cams) == 0: + return None + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def create_nuscenes_pose_embeddings_framepack(self, scene_info, segment_info): + """创建NuScenes风格的pose embeddings - FramePack版本(简化版本,直接7维)""" + keyframe_poses = scene_info['keyframe_poses'] + reference_keyframe_idx = segment_info['reference_keyframe_idx'] + target_keyframe_indices = segment_info['target_keyframe_indices'] + + if reference_keyframe_idx >= len(keyframe_poses): + return None + + reference_pose = keyframe_poses[reference_keyframe_idx] + + # 为所有帧(condition + target)创建pose embeddings + start_frame = segment_info['start_frame'] + condition_end_compressed = start_frame + segment_info['condition_frames'] + target_end_compressed = condition_end_compressed + segment_info['target_frames'] + + # 压缩后的关键帧索引 + compressed_keyframe_indices = [idx // self.time_compression_ratio for idx in scene_info['keyframe_indices']] + + # 找到condition段的关键帧 + condition_keyframes_compressed = [idx for idx in compressed_keyframe_indices + if start_frame <= idx < 
condition_end_compressed] + + # 找到对应的原始关键帧索引 + condition_keyframes_original_indices = [] + for compressed_idx in condition_keyframes_compressed: + for i, comp_idx in enumerate(compressed_keyframe_indices): + if comp_idx == compressed_idx: + condition_keyframes_original_indices.append(i) + break + + pose_vecs = [] + + # 为condition帧计算pose + for i in range(segment_info['condition_frames']): + if not condition_keyframes_original_indices: + translation = torch.zeros(3, dtype=torch.float32) + rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) + else: + # 为condition帧分配pose + if len(condition_keyframes_original_indices) == 1: + keyframe_idx = condition_keyframes_original_indices[0] + else: + if segment_info['condition_frames'] == 1: + keyframe_idx = condition_keyframes_original_indices[0] + else: + interp_ratio = i / (segment_info['condition_frames'] - 1) + interp_idx = int(interp_ratio * (len(condition_keyframes_original_indices) - 1)) + keyframe_idx = condition_keyframes_original_indices[interp_idx] + + if keyframe_idx >= len(keyframe_poses): + translation = torch.zeros(3, dtype=torch.float32) + rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) + else: + condition_pose = keyframe_poses[keyframe_idx] + + translation = torch.tensor( + np.array(condition_pose['translation']) - np.array(reference_pose['translation']), + dtype=torch.float32 + ) + + relative_rotation = self.calculate_relative_rotation( + condition_pose['rotation'], + reference_pose['rotation'] + ) + + rotation = relative_rotation + + # 🔧 简化:直接7维 [translation(3) + rotation(4)] + pose_vec = torch.cat([translation, rotation], dim=0) # [7D] + pose_vecs.append(pose_vec) + + # 为target帧计算pose + if not target_keyframe_indices: + for i in range(segment_info['target_frames']): + pose_vec = torch.cat([ + torch.zeros(3, dtype=torch.float32), + torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32), + ], dim=0) # [7D] + pose_vecs.append(pose_vec) + else: + for i in 
range(segment_info['target_frames']): + if len(target_keyframe_indices) == 1: + target_keyframe_idx = target_keyframe_indices[0] + else: + if segment_info['target_frames'] == 1: + target_keyframe_idx = target_keyframe_indices[0] + else: + interp_ratio = i / (segment_info['target_frames'] - 1) + interp_idx = int(interp_ratio * (len(target_keyframe_indices) - 1)) + target_keyframe_idx = target_keyframe_indices[interp_idx] + + if target_keyframe_idx >= len(keyframe_poses): + pose_vec = torch.cat([ + torch.zeros(3, dtype=torch.float32), + torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32), + ], dim=0) # [7D] + else: + target_pose = keyframe_poses[target_keyframe_idx] + + relative_translation = torch.tensor( + np.array(target_pose['translation']) - np.array(reference_pose['translation']), + dtype=torch.float32 + ) + + relative_rotation = self.calculate_relative_rotation( + target_pose['rotation'], + reference_pose['rotation'] + ) + + # 🔧 简化:直接7维 [translation(3) + rotation(4)] + pose_vec = torch.cat([relative_translation, relative_rotation], dim=0) # [7D] + + pose_vecs.append(pose_vec) + + if not pose_vecs: + return None + + pose_sequence = torch.stack(pose_vecs, dim=0) # [total_frames, 7] + + return pose_sequence + + # 修改select_dynamic_segment方法 + def select_dynamic_segment(self, full_latents, dataset_type, scene_info=None): + """🔧 根据数据集类型选择不同的段落选择策略""" + if dataset_type == 'nuscenes' and scene_info is not None: + return self.select_dynamic_segment_nuscenes(scene_info) + else: + # 原有的sekai方式 + total_lens = full_latents.shape[1] + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + max_condition_compressed = min(total_lens-target_frames_compressed-1, max_condition_compressed) + + ratio = random.random() + if ratio < 0.15: + condition_frames_compressed = 1 + elif 0.15 
<= ratio < 0.9 or total_lens <= 2*target_frames_compressed + 1: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if total_lens < min_required_frames: + return None + + start_frame_compressed = random.randint(0, total_lens - min_required_frames - 1) + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + # FramePack风格的索引处理 + latent_indices = torch.arange(condition_end_compressed, target_end_compressed) + + # 1x帧:起始帧 + 最后1帧 + clean_latent_indices_start = torch.tensor([start_frame_compressed]) + clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1]) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices]) + + # 🔧 2x帧:根据实际condition长度确定 + if condition_frames_compressed >= 2: + # 取最后2帧(如果有的话) + clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2-1) + clean_latent_2x_indices = torch.arange(clean_latent_2x_start, condition_end_compressed-1) + else: + # 如果condition帧数不足2帧,创建空索引 + clean_latent_2x_indices = torch.tensor([], dtype=torch.long) + + # 🔧 4x帧:根据实际condition长度确定,最多16帧 + if condition_frames_compressed > 3: + # 取最多16帧的历史(如果有的话) + clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16-3) + clean_latent_4x_indices = torch.arange(clean_4x_start, condition_end_compressed-3) + else: + clean_latent_4x_indices = torch.tensor([], dtype=torch.long) + + # 对应的原始关键帧索引 + keyframe_original_idx = [] + for compressed_idx in range(start_frame_compressed, target_end_compressed): + if dataset_type == 'spatialvid': + keyframe_original_idx.append(compressed_idx) # spatialvid直接使用compressed_idx + elif dataset_type == 'openx' or 'sekai': # 🔧 新增openx处理 + 
keyframe_original_idx.append(compressed_idx * 4) # openx使用4倍间隔 + + return { + 'start_frame': start_frame_compressed, + 'condition_frames': condition_frames_compressed, + 'target_frames': target_frames_compressed, + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + + # FramePack风格的索引 + 'latent_indices': latent_indices, + 'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + + 'keyframe_original_idx': keyframe_original_idx, + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + + # 修改create_pose_embeddings方法 + def create_pose_embeddings(self, cam_data, segment_info, dataset_type, scene_info=None): + """🔧 根据数据集类型创建pose embeddings""" + if dataset_type == 'nuscenes' and scene_info is not None: + return self.create_nuscenes_pose_embeddings_framepack(scene_info, segment_info) + elif dataset_type == 'spatialvid': # 🔧 新增spatialvid处理 + return self.create_spatialvid_pose_embeddings(cam_data, segment_info) + elif dataset_type == 'sekai': + return self.create_sekai_pose_embeddings(cam_data, segment_info) + elif dataset_type == 'openx': # 🔧 新增openx处理 + return self.create_openx_pose_embeddings(cam_data, segment_info) + + def __getitem__(self, index): + while True: + try: + # 根据权重随机选择场景 + scene_idx = np.random.choice(len(self.scene_dirs), p=self.sampling_probs) + scene_dir = self.scene_dirs[scene_idx] + dataset_info = self.dataset_info[scene_dir] + + dataset_name = dataset_info['name'] + dataset_type = dataset_info['type'] + + # 🔧 根据数据集类型加载数据 + scene_info = None + if dataset_type == 'nuscenes': + # NuScenes需要加载scene_info.json + scene_info_path = os.path.join(scene_dir, "scene_info.json") + if os.path.exists(scene_info_path): + with open(scene_info_path, 'r') as f: + 
scene_info = json.load(f) + + # NuScenes使用不同的编码文件名 + encoded_path = os.path.join(scene_dir, "encoded_video-480p.pth") + if not os.path.exists(encoded_path): + encoded_path = os.path.join(scene_dir, "encoded_video.pth") # fallback + + encoded_data = torch.load(encoded_path, weights_only=True, map_location="cpu") + else: + # Sekai数据集 + encoded_path = os.path.join(scene_dir, "encoded_video.pth") + encoded_data = torch.load(encoded_path, weights_only=False, map_location="cpu") + + full_latents = encoded_data['latents'] + if full_latents.shape[1] <= 10: + continue + cam_data = encoded_data.get('cam_emb', encoded_data) + + # 🔧 验证NuScenes的latent帧数 + if dataset_type == 'nuscenes' and scene_info is not None: + expected_latent_frames = scene_info['total_frames'] // self.time_compression_ratio + actual_latent_frames = full_latents.shape[1] + + if abs(actual_latent_frames - expected_latent_frames) > 2: + print(f"⚠️ NuScenes Latent帧数不匹配,跳过此样本") + continue + + # 使用数据集特定的段落选择策略 + segment_info = self.select_dynamic_segment(full_latents, dataset_type, scene_info) + if segment_info is None: + continue + + # 创建数据集特定的pose embeddings + all_camera_embeddings = self.create_pose_embeddings(cam_data, segment_info, dataset_type, scene_info) + if all_camera_embeddings is None: + continue + + # 准备FramePack风格的多尺度输入 + framepack_inputs = self.prepare_framepack_inputs(full_latents, segment_info) + + n = segment_info["condition_frames"] + m = segment_info['target_frames'] + + # 处理camera embedding with mask + mask = torch.zeros(n+m, dtype=torch.float32) + mask[:n] = 1.0 + mask = mask.view(-1, 1) + + # 🔧 NuScenes返回的是直接的embedding,Sekai返回的是tensor + if isinstance(all_camera_embeddings, torch.Tensor): + camera_with_mask = torch.cat([all_camera_embeddings, mask], dim=1) + else: + # NuScenes风格,直接就是最终的embedding + camera_with_mask = torch.cat([all_camera_embeddings, mask], dim=1) + + result = { + # FramePack风格的多尺度输入 + "latents": framepack_inputs['latents'], + "clean_latents": 
framepack_inputs['clean_latents'],
                    "clean_latents_2x": framepack_inputs['clean_latents_2x'],
                    "clean_latents_4x": framepack_inputs['clean_latents_4x'],
                    "latent_indices": framepack_inputs['latent_indices'],
                    "clean_latent_indices": framepack_inputs['clean_latent_indices'],
                    "clean_latent_2x_indices": framepack_inputs['clean_latent_2x_indices'],
                    "clean_latent_4x_indices": framepack_inputs['clean_latent_4x_indices'],

                    # Camera data (pose embedding concatenated with condition mask)
                    "camera": camera_with_mask,

                    # Other data
                    "prompt_emb": encoded_data["prompt_emb"],
                    "image_emb": encoded_data.get("image_emb", {}),

                    # Metadata
                    "condition_frames": n,
                    "target_frames": m,
                    "scene_name": os.path.basename(scene_dir),
                    "dataset_name": dataset_name,
                    "dataset_type": dataset_type,
                    "original_condition_frames": segment_info['original_condition_frames'],
                    "original_target_frames": segment_info['original_target_frames'],
                }

                return result

            except Exception as e:
                # NOTE(review): broad best-effort retry — any failure skips the
                # sample and re-draws. If every scene fails, this loops forever;
                # consider a bounded retry count.
                print(f"Error loading sample: {e}")
                traceback.print_exc()
                continue

    def __len__(self):
        # Virtual epoch length: samples are drawn randomly, so the dataset
        # size is just the configured number of steps per epoch.
        return self.steps_per_epoch

def replace_dit_model_in_manager():
    """Replace the DiT model class with the MoE version before model loading.

    Mutates diffsynth's global `model_loader_configs` in place so that any
    config entry containing 'wan_video_dit' loads `WanModelMoe` instead of
    the stock class.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    # Rewrite matching entries of model_loader_configs.
    for i, config in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config

        # Only touch configs that include the wan_video_dit model.
        if 'wan_video_dit' in model_names:
            new_model_names = []
            new_model_classes = []

            for name, cls in zip(model_names, model_classes):
                if name == 'wan_video_dit':
                    new_model_names.append(name)
                    new_model_classes.append(WanModelMoe)  # swap in the MoE variant
                    print(f"✅ 替换了模型类: {name} -> WanModelMoe")
                else:
                    new_model_names.append(name)
                    new_model_classes.append(cls)

            # Write the updated tuple back into the global config list.
            model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource)

class 
MultiDatasetLightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + resume_ckpt_path=None, + # 🔧 MoE参数 + use_moe=False, + moe_config=None + ): + super().__init__() + self.use_moe = use_moe + self.moe_config = moe_config or {} + + replace_dit_model_in_manager() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + # 添加FramePack的clean_x_embedder + self.add_framepack_components() + if self.use_moe: + self.add_moe_components() + + # 🔧 添加camera编码器(wan_video_dit_moe.py已经包含MoE逻辑) + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + # 🔧 简化:只添加传统camera编码器,MoE逻辑在wan_video_dit_moe.py中 + block.cam_encoder = nn.Linear(13, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + state_dict.pop("global_router.weight", None) + state_dict.pop("global_router.bias", None) + self.pipe.dit.load_state_dict(state_dict, strict=False) + print('load checkpoint:', resume_ckpt_path) + + self.freeze_parameters() + + # 🔧 训练参数设置 + for name, module in self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn", "clean_x_embedder", + "moe", "sekai_processor", "nuscenes_processor","openx_processor"]): + for 
param in module.parameters(): + param.requires_grad = True + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + # 创建可视化目录 + self.vis_dir = "multi_dataset_dynamic/visualizations" + os.makedirs(self.vis_dir, exist_ok=True) + + def add_moe_components(self): + """🔧 添加MoE相关组件 - 简化版,只为每个block添加MoE,全局processor在WanModelMoe中""" + if not hasattr(self.pipe.dit, 'moe_config'): + self.pipe.dit.moe_config = self.moe_config + print("✅ 添加了MoE配置到模型") + self.pipe.dit.top_k = self.moe_config.get("top_k", 1) + + # 为每个block添加MoE组件(modality processors已经在WanModelMoe中全局创建) + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + unified_dim = self.moe_config.get("unified_dim", 30) + num_experts = self.moe_config.get("num_experts", 4) + from diffsynth.models.wan_video_dit_moe import MultiModalMoE, ModalityProcessor + + self.pipe.dit.sekai_processor = ModalityProcessor("sekai", 13, unified_dim) + self.pipe.dit.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim) + self.pipe.dit.openx_processor = ModalityProcessor("openx", 13, unified_dim) # OpenX使用13维输入,类似sekai但独立处理 + self.pipe.dit.global_router = nn.Linear(unified_dim, num_experts) + + for i, block in enumerate(self.pipe.dit.blocks): + # 只为每个block添加MoE网络 + block.moe = MultiModalMoE( + unified_dim=unified_dim, + output_dim=dim, + num_experts=self.moe_config.get("num_experts", 4), + top_k=self.moe_config.get("top_k", 2) + ) + + print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {self.moe_config.get('num_experts', 4)})") + + + def add_framepack_components(self): + """🔧 添加FramePack相关组件""" + if not hasattr(self.pipe.dit, 'clean_x_embedder'): + inner_dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + 
self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + if scale == "1x": + return self.proj(x) + elif scale == "2x": + return self.proj_2x(x) + elif scale == "4x": + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + self.pipe.dit.clean_x_embedder = CleanXEmbedder(inner_dim) + print("✅ 添加了FramePack的clean_x_embedder组件") + + def freeze_parameters(self): + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + def training_step(self, batch, batch_idx): + """🔧 多数据集训练步骤""" + condition_frames = batch["condition_frames"][0].item() + target_frames = batch["target_frames"][0].item() + + original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0] + original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0] + + dataset_name = batch.get("dataset_name", ["unknown"])[0] + dataset_type = batch.get("dataset_type", ["sekai"])[0] + scene_name = batch.get("scene_name", ["unknown"])[0] + + # 准备输入数据 + latents = batch["latents"].to(self.device) + if len(latents.shape) == 4: + latents = latents.unsqueeze(0) + + clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None + if clean_latents is not None and len(clean_latents.shape) == 4: + clean_latents = clean_latents.unsqueeze(0) + + clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None + if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4: + clean_latents_2x = clean_latents_2x.unsqueeze(0) + + clean_latents_4x = batch["clean_latents_4x"].to(self.device) if batch["clean_latents_4x"].numel() > 0 else None + if clean_latents_4x is not None and len(clean_latents_4x.shape) == 4: + clean_latents_4x = clean_latents_4x.unsqueeze(0) + + # 索引处理 + 
latent_indices = batch["latent_indices"].to(self.device) + clean_latent_indices = batch["clean_latent_indices"].to(self.device) if batch["clean_latent_indices"].numel() > 0 else None + clean_latent_2x_indices = batch["clean_latent_2x_indices"].to(self.device) if batch["clean_latent_2x_indices"].numel() > 0 else None + clean_latent_4x_indices = batch["clean_latent_4x_indices"].to(self.device) if batch["clean_latent_4x_indices"].numel() > 0 else None + + # Camera embedding处理 + cam_emb = batch["camera"].to(self.device) + + # 🔧 根据数据集类型设置modality_inputs + if dataset_type == "sekai": + modality_inputs = {"sekai": cam_emb} + elif dataset_type == "spatialvid": # 🔧 spatialvid使用sekai processor + modality_inputs = {"sekai": cam_emb} # 注意:这里使用"sekai"键 + elif dataset_type == "nuscenes": + modality_inputs = {"nuscenes": cam_emb} + elif dataset_type == "openx": # 🔧 新增:openx使用独立的processor + modality_inputs = {"openx": cam_emb} + else: + modality_inputs = {"sekai": cam_emb} # 默认 + + camera_dropout_prob = 0.05 + if random.random() < camera_dropout_prob: + cam_emb = torch.zeros_like(cam_emb) + # 同时清空modality_inputs + for key in modality_inputs: + modality_inputs[key] = torch.zeros_like(modality_inputs[key]) + print(f"应用camera dropout for CFG training (dataset: {dataset_name}, type: {dataset_type})") + + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + # Loss计算 + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + # FramePack风格的噪声处理 + noisy_condition_latents = None + if clean_latents is not None: + 
noisy_condition_latents = copy.deepcopy(clean_latents) + is_add_noise = random.random() + if is_add_noise > 0.2: + noise_cond = torch.randn_like(clean_latents) + timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps//4*3, (1,)) + timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + noisy_condition_latents = self.pipe.scheduler.add_noise(clean_latents, noise_cond, timestep_cond) + + extra_input = self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + noise_pred, specialization_loss = self.pipe.denoising_model()( + noisy_latents, + timestep=timestep, + cam_emb=cam_emb, + modality_inputs=modality_inputs, # 🔧 传递多模态输入 + latent_indices=latent_indices, + clean_latents=noisy_condition_latents if noisy_condition_latents is not None else clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb, + **extra_input, + **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + + # 计算loss + # 🔧 计算总loss = 重建loss + MoE专业化loss + reconstruction_loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + reconstruction_loss = reconstruction_loss * self.pipe.scheduler.training_weight(timestep) + + # 🔧 添加MoE专业化loss(交叉熵损失) + specialization_loss_weight = self.moe_config.get("moe_loss_weight", 0.1) + total_loss = reconstruction_loss + specialization_loss_weight * specialization_loss + + print(f'\n loss info (step {self.global_step}):') + print(f' - diff loss: {reconstruction_loss.item():.6f}') + print(f' - 
MoE specification loss: {specialization_loss.item():.6f}') + print(f' - Expert loss weight: {specialization_loss_weight}') + print(f' - Total Loss: {total_loss.item():.6f}') + + # 🔧 显示预期的专家映射 + modality_to_expert = { + "sekai": 0, + "nuscenes": 1, + "openx": 2 + } + expected_expert = modality_to_expert.get(dataset_type, 0) + print(f' - current modality: {dataset_type} -> expected expert: {expected_expert}') + + return total_loss + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + def on_save_checkpoint(self, checkpoint): + checkpoint_dir = "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe" + os.makedirs(checkpoint_dir, exist_ok=True) + + current_step = self.global_step + checkpoint.clear() + + state_dict = self.pipe.denoising_model().state_dict() + torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}_origin_other_continue3.ckpt")) + print(f"Saved MoE model checkpoint: step{current_step}_origin.ckpt") + +def train_multi_dataset(args): + """训练支持多数据集MoE的模型""" + + # 🔧 数据集配置 + dataset_configs = [ + { + 'name': 'sekai-drone', + 'paths': ['/share_zhuyixuan05/zhuyixuan05/sekai-game-drone'], + 'type': 'sekai', + 'weight': 0.7 + }, + { + 'name': 'sekai-walking', + 'paths': ['/share_zhuyixuan05/zhuyixuan05/sekai-game-walking'], + 'type': 'sekai', + 'weight': 0.7 + }, + { + 'name': 'spatialvid', + 'paths': ['/share_zhuyixuan05/zhuyixuan05/spatialvid'], + 'type': 'spatialvid', + 'weight': 1.0 + }, + { + 'name': 'nuscenes', + 'paths': ['/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic'], + 'type': 'nuscenes', + 'weight': 7.0 + }, + { + 'name': 'openx-fractal', + 'paths': ['/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded'], + 'type': 'openx', + 'weight': 1.1 + } + ] + + dataset = MultiDatasetDynamicDataset( + dataset_configs, + 
steps_per_epoch=args.steps_per_epoch, + min_condition_frames=args.min_condition_frames, + max_condition_frames=args.max_condition_frames, + target_frames=args.target_frames, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + + # 🔧 MoE配置 + moe_config = { + "unified_dim": args.unified_dim, # 新增 + "num_experts": args.moe_num_experts, + "top_k": args.moe_top_k, + "moe_loss_weight": args.moe_loss_weight, + "sekai_input_dim": 13, + "nuscenes_input_dim": 8, + "openx_input_dim": 13 + } + + model = MultiDatasetLightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + use_moe=True, # 总是使用MoE + moe_config=moe_config + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[], + logger=False + ) + trainer.fit(model, dataloader) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Train Multi-Dataset FramePack with MoE") + parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors") + parser.add_argument("--output_path", type=str, default="./") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--steps_per_epoch", type=int, default=200000) + parser.add_argument("--max_epochs", type=int, default=100000) + parser.add_argument("--min_condition_frames", type=int, default=8, help="最小条件帧数") + parser.add_argument("--max_condition_frames", type=int, default=120, help="最大条件帧数") + parser.add_argument("--target_frames", type=int, default=32, help="目标帧数") + 
parser.add_argument("--dataloader_num_workers", type=int, default=4) + parser.add_argument("--accumulate_grad_batches", type=int, default=1) + parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1") + parser.add_argument("--use_gradient_checkpointing", default=False) + parser.add_argument("--use_gradient_checkpointing_offload", action="store_true") + parser.add_argument("--resume_ckpt_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step23000_origin_other_continue_con.ckpt") + + # 🔧 MoE参数 + parser.add_argument("--unified_dim", type=int, default=25, help="统一的中间维度") + parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量") + parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家") + parser.add_argument("--moe_loss_weight", type=float, default=0.1, help="MoE损失权重") + + args = parser.parse_args() + + print("🔧 多数据集MoE训练配置:") + print(f" - 使用wan_video_dit_moe.py作为模型") + print(f" - 统一维度: {args.unified_dim}") + print(f" - 专家数量: {args.moe_num_experts}") + print(f" - Top-K: {args.moe_top_k}") + print(f" - MoE损失权重: {args.moe_loss_weight}") + print(" - 数据集:") + print(" - sekai-game-drone (sekai模态)") + print(" - sekai-game-walking (sekai模态)") + print(" - spatialvid (使用sekai模态处理器)") + print(" - openx-fractal (使用sekai模态处理器)") + print(f" - nuscenes (nuscenes模态)") + + train_multi_dataset(args) \ No newline at end of file diff --git a/scripts/train_recam.py b/scripts/train_recam.py new file mode 100644 index 0000000000000000000000000000000000000000..7de769227b432c8a4aeb472d51f711419e355a5e --- /dev/null +++ b/scripts/train_recam.py @@ -0,0 +1,692 @@ +import copy +import os +import re +import torch, os, imageio, argparse +from torchvision.transforms import v2 +from einops import rearrange +import lightning as pl +import pandas as pd +from diffsynth import WanVideoReCamMasterPipeline, ModelManager, load_state_dict +import torchvision +from PIL import Image +import numpy as np +import random 
+import json +import torch.nn as nn +import torch.nn.functional as F +import shutil +import wandb +import pdb + +class TextVideoDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False): + metadata = pd.read_csv(metadata_path) + self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] + self.text = metadata["text"].to_list() + + self.max_num_frames = max_num_frames + self.frame_interval = frame_interval + self.num_frames = num_frames + self.height = height + self.width = width + self.is_i2v = is_i2v + + self.frame_process = v2.Compose([ + v2.CenterCrop(size=(height, width)), + v2.Resize(size=(height, width), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + + def crop_and_resize(self, image): + width, height = image.size + scale = max(self.width / width, self.height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + return image + + + def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process): + reader = imageio.get_reader(file_path) + if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: + reader.close() + return None + + frames = [] + first_frame = None + for frame_id in range(num_frames): + frame = reader.get_data(start_frame_id + frame_id * interval) + frame = Image.fromarray(frame) + frame = self.crop_and_resize(frame) + if first_frame is None: + first_frame = np.array(frame) + frame = frame_process(frame) + frames.append(frame) + reader.close() + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + + if self.is_i2v: + return frames, first_frame + else: + return 
frames + + + def load_video(self, file_path): + start_frame_id = 0 + frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process) + return frames + + + def is_image(self, file_path): + file_ext_name = file_path.split(".")[-1] + if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]: + return True + return False + + + def load_image(self, file_path): + frame = Image.open(file_path).convert("RGB") + frame = self.crop_and_resize(frame) + first_frame = frame + frame = self.frame_process(frame) + frame = rearrange(frame, "C H W -> C 1 H W") + return frame + + + def __getitem__(self, data_id): + text = self.text[data_id] + path = self.path[data_id] + while True: + try: + if self.is_image(path): + if self.is_i2v: + raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.") + video = self.load_image(path) + else: + video = self.load_video(path) + if self.is_i2v: + video, first_frame = video + data = {"text": text, "video": video, "path": path, "first_frame": first_frame} + else: + data = {"text": text, "video": video, "path": path} + break + except: + data_id += 1 + return data + + + def __len__(self): + return len(self.path) + + + +class LightningModelForDataProcess(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_path = [text_encoder_path, vae_path] + if image_encoder_path is not None: + model_path.append(image_encoder_path) + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models(model_path) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + def test_step(self, batch, batch_idx): + text, video, path = batch["text"][0], batch["video"], batch["path"][0] + + 
self.pipe.device = self.device + if video is not None: + pth_path = path + ".recam.pth" + if not os.path.exists(pth_path): + # prompt + prompt_emb = self.pipe.encode_prompt(text) + # video + video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] + # image + if "first_frame" in batch: + first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy()) + _, _, num_frames, height, width = video.shape + image_emb = self.pipe.encode_image(first_frame, num_frames, height, width) + else: + image_emb = {} + data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb} + torch.save(data, pth_path) + print(f"Output: {pth_path}") + else: + print(f"File {pth_path} already exists, skipping.") + +class Camera(object): + def __init__(self, c2w): + c2w_mat = np.array(c2w).reshape(4, 4) + self.c2w_mat = c2w_mat + self.w2c_mat = np.linalg.inv(c2w_mat) + + + +class TensorDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, steps_per_epoch, condition_frames=10, target_frames=5): + metadata = pd.read_csv(metadata_path) + self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] + print(len(self.path), "videos in metadata.") + self.path = [i + ".recam.pth" for i in self.path if os.path.exists(i + ".recam.pth")] + print(len(self.path), "tensors cached in metadata.") + assert len(self.path) > 0 + self.steps_per_epoch = steps_per_epoch + self.condition_frames = int(condition_frames) + self.target_frames = int(target_frames) + + + def parse_matrix(self, matrix_str): + rows = matrix_str.strip().split('] [') + matrix = [] + for row in rows: + row = row.replace('[', '').replace(']', '') + matrix.append(list(map(float, row.split()))) + return np.array(matrix) + + + def get_relative_pose(self, cam_params): + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + 
target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + + + def __getitem__(self, index): + # Return: + # data['latents']: torch.Size([16, T_target + T_cond, H, W]) + # data['camera']: torch.Size([T_target, 12]) + # data['prompt_emb']["context"][0]: torch.Size([512, 4096]) + while True: + try: + data = {} + data_id = torch.randint(0, len(self.path), (1,))[0] + data_id = (data_id + index) % len(self.path) # For fixed seed. + path_tgt = self.path[data_id] + data_tgt = torch.load(path_tgt, weights_only=True, map_location="cpu") + + # load the condition latent (不同相机) + match = re.search(r'cam(\d+)', path_tgt) + tgt_idx = int(match.group(1)) + cond_idx = random.randint(1, 10) + while cond_idx == tgt_idx: + cond_idx = random.randint(1, 10) + path_cond = re.sub(r'cam(\d+)', f'cam{cond_idx:02}', path_tgt) + data_cond = torch.load(path_cond, weights_only=True, map_location="cpu") + + # 截取 target 与 condition 帧并按 [target | condition] 拼接 + lat_tgt_full = data_tgt['latents'] # [C, T, H, W] + lat_cond_full = data_cond['latents'] # [C, T, H, W] + T_tgt_avail = lat_tgt_full.shape[1] + T_cond_avail = lat_cond_full.shape[1] + tgt_len = min(self.target_frames, T_tgt_avail) + cond_len = min(self.condition_frames, T_cond_avail) + lat_tgt = lat_tgt_full[:, :tgt_len, ...] + lat_cond = lat_cond_full[:, :cond_len, ...] 
+ data['latents'] = torch.cat((lat_tgt, lat_cond), dim=1) # [C, tgt_len+cond_len, H, W] + + data['prompt_emb'] = data_tgt['prompt_emb'] + data['image_emb'] = {} + # load the target trajectory -> 生成 target_len 帧的相机相对位姿嵌入 + base_path = path_tgt.rsplit('/', 2)[0] + tgt_camera_path = os.path.join(base_path, "cameras", "camera_extrinsics.json") + with open(tgt_camera_path, 'r') as file: + cam_data = json.load(file) + + # 均匀采样 target_len 帧的时间索引(0~80) + cam_idx = np.linspace(0, 80, tgt_len, dtype=int).tolist() + + multiview_c2ws = [] + for view_idx in [cond_idx, tgt_idx]: + traj = [self.parse_matrix(cam_data[f"frame{idx}"][f"cam{view_idx:02d}"]) for idx in cam_idx] + # 统一为 4x4 c2w + c2ws = [] + for m in traj: + m = np.array(m) + if m.shape == (3, 4): + m = np.vstack([m, np.array([0, 0, 0, 1.0])]) + elif m.shape != (4, 4): + raise ValueError(f"Unexpected c2w shape: {m.shape}") + c2ws.append(m) + multiview_c2ws.append(c2ws) + + cond_cam_params = [Camera(cam_param) for cam_param in multiview_c2ws[0]] + tgt_cam_params = [Camera(cam_param) for cam_param in multiview_c2ws[1]] + relative_poses = [] + for i in range(len(tgt_cam_params)): + relative_pose = self.get_relative_pose([cond_cam_params[0], tgt_cam_params[i]]) + # 取目标相机在相对坐标下的 3x4 + relative_poses.append(torch.as_tensor(relative_pose)[:,:3,:][1]) + + pose_embedding = torch.stack(relative_poses, dim=0) # [tgt_len, 3, 4] + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') # [tgt_len, 12] + data['camera'] = pose_embedding.to(torch.bfloat16) + break + except Exception as e: + print(f"ERROR WHEN LOADING: {e}") + index = random.randrange(len(self.path)) + return data + + def __len__(self): + return self.steps_per_epoch + + + +class LightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, + resume_ckpt_path=None, + condition_frames=10, + target_frames=5, + ): + super().__init__() + model_manager 
= ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + dim=self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + block.cam_encoder = nn.Linear(12, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + self.pipe.dit.load_state_dict(state_dict, strict=True) + + self.freeze_parameters() + for name, module in self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn"]): + print(f"Trainable: {name}") + for param in module.parameters(): + param.requires_grad = True + self.condition_frames = int(condition_frames) + self.target_frames = int(target_frames) + trainable_params = 0 + seen_params = set() + for name, module in self.pipe.denoising_model().named_modules(): + for param in module.parameters(): + if param.requires_grad and param not in seen_params: + trainable_params += param.numel() + seen_params.add(param) + print(f"Total number of trainable parameters: {trainable_params}") + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + + def freeze_parameters(self): + # Freeze parameters + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + + def training_step(self, batch, batch_idx): + # Data + latents = batch["latents"].to(self.device) 
# [B, C, T, H, W], T = tgt_len + cond_len + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + + target_height, target_width = 40, 70 # 根据你的需求调整 + current_height, current_width = latents.shape[3], latents.shape[4] + + if current_height > target_height or current_width > target_width: + # 中心裁剪 + h_start = (current_height - target_height) // 2 + w_start = (current_width - target_width) // 2 + latents = latents[:, :, :, + h_start:h_start+target_height, + w_start:w_start+target_width] + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + cam_emb = batch["camera"].to(self.device) # [B, tgt_len, 12] after collate + + # Loss + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + extra_input = self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + # 仅保留 condition 段(后半段)为干净;target 段(前 tgt_len 帧)参与去噪训练 + tgt_len = self.target_frames + assert noisy_latents.shape[2] >= tgt_len, f"Latent T {noisy_latents.shape[2]} < target_frames {tgt_len}" + noisy_latents[:, :, tgt_len:, ...] = origin_latents[:, :, tgt_len:, ...] 
+ training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # Compute loss (只计算 target 段) + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + loss = torch.nn.functional.mse_loss( + noise_pred[:, :, :tgt_len, ...].float(), + training_target[:, :, :tgt_len, ...].float() + ) + loss = loss * self.pipe.scheduler.training_weight(timestep) + + wandb.log({"train_loss": loss.item()}) + return loss + + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + + def on_save_checkpoint(self, checkpoint): + checkpoint_dir = "/home/zhuyixuan05/ReCamMaster/models/checkpoints" + print(f"Checkpoint directory: {checkpoint_dir}") + current_step = self.global_step + print(f"Current step: {current_step}") + + checkpoint.clear() + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters())) + trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) + state_dict = self.pipe.denoising_model().state_dict() + torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}.ckpt")) + + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train ReCamMaster") + parser.add_argument( + "--task", + type=str, + default="train", + choices=["data_process", "train"], + help="Task. 
`data_process` or `train`.", + ) + parser.add_argument( + "--dataset_path", + type=str, + default="/share_zhuyixuan05/zhuyixuan05/MultiCamVideo-Dataset", + help="The path of the Dataset.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./", + help="Path to save the model.", + ) + parser.add_argument( + "--text_encoder_path", + type=str, + default=None, + help="Path of text encoder.", + ) + parser.add_argument( + "--image_encoder_path", + type=str, + default=None, + help="Path of image encoder.", + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help="Path of VAE.", + ) + parser.add_argument( + "--dit_path", + type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + help="Path of DiT.", + ) + parser.add_argument( + "--tiled", + default=False, + action="store_true", + help="Whether enable tile encode in VAE. This option can reduce VRAM required.", + ) + parser.add_argument( + "--tile_size_height", + type=int, + default=34, + help="Tile size (height) in VAE.", + ) + parser.add_argument( + "--tile_size_width", + type=int, + default=34, + help="Tile size (width) in VAE.", + ) + parser.add_argument( + "--tile_stride_height", + type=int, + default=18, + help="Tile stride (height) in VAE.", + ) + parser.add_argument( + "--tile_stride_width", + type=int, + default=16, + help="Tile stride (width) in VAE.", + ) + parser.add_argument( + "--steps_per_epoch", + type=int, + default=100, + help="Number of steps per epoch.", + ) + parser.add_argument( + "--num_frames", + type=int, + default=81, + help="Number of frames.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="Image height.", + ) + parser.add_argument( + "--width", + type=int, + default=832, + help="Image width.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=4, + help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Learning rate.", + ) + parser.add_argument( + "--accumulate_grad_batches", + type=int, + default=1, + help="The number of batches in gradient accumulation.", + ) + parser.add_argument( + "--max_epochs", + type=int, + default=2, + help="Number of epochs.", + ) + parser.add_argument( + "--training_strategy", + type=str, + default="deepspeed_stage_1", + choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"], + help="Training strategy", + ) + parser.add_argument( + "--use_gradient_checkpointing", + default=False, + action="store_true", + help="Whether to use gradient checkpointing.", + ) + parser.add_argument( + "--use_gradient_checkpointing_offload", + default=False, + action="store_true", + help="Whether to use gradient checkpointing offload.", + ) + parser.add_argument( + "--use_swanlab", + default=True, + action="store_true", + help="Whether to use SwanLab logger.", + ) + parser.add_argument( + "--swanlab_mode", + default="cloud", + help="SwanLab mode (cloud or local).", + ) + parser.add_argument( + "--metadata_file_name", + type=str, + default="metadata.csv", + ) + parser.add_argument( + "--resume_ckpt_path", + type=str, + default=None, + ) + parser.add_argument( + "--condition_frames", + type=int, + default=20, + help="Number of condition frames (kept clean).", + ) + parser.add_argument( + "--target_frames", + type=int, + default=10, + help="Number of target frames (to be denoised).", + ) + args = parser.parse_args() + return args + + +def data_process(args): + dataset = TextVideoDataset( + args.dataset_path, + os.path.join(args.dataset_path, args.metadata_file_name), + max_num_frames=args.num_frames, + frame_interval=1, + num_frames=args.num_frames, + height=args.height, + width=args.width, + is_i2v=args.image_encoder_path is not None + ) + dataloader = torch.utils.data.DataLoader( + dataset, + 
shuffle=False, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForDataProcess( + text_encoder_path=args.text_encoder_path, + image_encoder_path=args.image_encoder_path, + vae_path=args.vae_path, + tiled=args.tiled, + tile_size=(args.tile_size_height, args.tile_size_width), + tile_stride=(args.tile_stride_height, args.tile_stride_width), + ) + trainer = pl.Trainer( + accelerator="gpu", + devices="auto", + default_root_dir=args.output_path, + ) + trainer.test(model, dataloader) + + +def train(args): + dataset = TensorDataset( + args.dataset_path, + os.path.join(args.dataset_path, "metadata.csv"), + steps_per_epoch=args.steps_per_epoch, + condition_frames=args.condition_frames, + target_frames=args.target_frames, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + condition_frames=args.condition_frames, + target_frames=args.target_frames, + ) + + if args.use_swanlab: + wandb.init( + project="recam", + name="recam", + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[], + ) + trainer.fit(model, dataloader) + + +if __name__ == '__main__': + args = parse_args() + os.makedirs(os.path.join(args.output_path, "checkpoints"), exist_ok=True) + if args.task == "data_process": + data_process(args) + elif args.task == "train": + train(args) \ No newline at end of file diff --git a/scripts/train_recam_future.py b/scripts/train_recam_future.py new file mode 100644 index 
0000000000000000000000000000000000000000..7eeb85d2d6dc3c6da0e18bdd80591df9ecb1503d --- /dev/null +++ b/scripts/train_recam_future.py @@ -0,0 +1,703 @@ +import copy +import os +import re +import torch, os, imageio, argparse +from torchvision.transforms import v2 +from einops import rearrange +import lightning as pl +import pandas as pd +from diffsynth import WanVideoReCamMasterPipeline, ModelManager, load_state_dict +import torchvision +from PIL import Image +import numpy as np +import random +import json +import torch.nn as nn +import torch.nn.functional as F +import shutil +import wandb +import pdb + +class TextVideoDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False): + metadata = pd.read_csv(metadata_path) + self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] + self.text = metadata["text"].to_list() + + self.max_num_frames = max_num_frames + self.frame_interval = frame_interval + self.num_frames = num_frames + self.height = height + self.width = width + self.is_i2v = is_i2v + + self.frame_process = v2.Compose([ + v2.CenterCrop(size=(height, width)), + v2.Resize(size=(height, width), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + + def crop_and_resize(self, image): + width, height = image.size + scale = max(self.width / width, self.height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + return image + + + def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process): + reader = imageio.get_reader(file_path) + if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: + reader.close() + return None + + 
frames = [] + first_frame = None + for frame_id in range(num_frames): + frame = reader.get_data(start_frame_id + frame_id * interval) + frame = Image.fromarray(frame) + frame = self.crop_and_resize(frame) + if first_frame is None: + first_frame = np.array(frame) + frame = frame_process(frame) + frames.append(frame) + reader.close() + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + + if self.is_i2v: + return frames, first_frame + else: + return frames + + + def load_video(self, file_path): + start_frame_id = 0 + frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process) + return frames + + + def is_image(self, file_path): + file_ext_name = file_path.split(".")[-1] + if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]: + return True + return False + + + def load_image(self, file_path): + frame = Image.open(file_path).convert("RGB") + frame = self.crop_and_resize(frame) + first_frame = frame + frame = self.frame_process(frame) + frame = rearrange(frame, "C H W -> C 1 H W") + return frame + + + def __getitem__(self, data_id): + text = self.text[data_id] + path = self.path[data_id] + while True: + try: + if self.is_image(path): + if self.is_i2v: + raise ValueError(f"{path} is not a video. 
I2V model doesn't support image-to-image training.") + video = self.load_image(path) + else: + video = self.load_video(path) + if self.is_i2v: + video, first_frame = video + data = {"text": text, "video": video, "path": path, "first_frame": first_frame} + else: + data = {"text": text, "video": video, "path": path} + break + except: + data_id += 1 + return data + + + def __len__(self): + return len(self.path) + + + +class LightningModelForDataProcess(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_path = [text_encoder_path, vae_path] + if image_encoder_path is not None: + model_path.append(image_encoder_path) + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models(model_path) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + def test_step(self, batch, batch_idx): + text, video, path = batch["text"][0], batch["video"], batch["path"][0] + + self.pipe.device = self.device + if video is not None: + pth_path = path + ".recam.pth" + if not os.path.exists(pth_path): + # prompt + prompt_emb = self.pipe.encode_prompt(text) + # video + video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] + # image + if "first_frame" in batch: + first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy()) + _, _, num_frames, height, width = video.shape + image_emb = self.pipe.encode_image(first_frame, num_frames, height, width) + else: + image_emb = {} + data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb} + torch.save(data, pth_path) + print(f"Output: {pth_path}") + else: + print(f"File {pth_path} already exists, skipping.") + +class Camera(object): + def __init__(self, c2w): + 
class TensorDataset(torch.utils.data.Dataset):
    """Dataset over pre-encoded video latents (``*.recam.pth`` files).

    Each sample is ``condition_frames`` clean latent frames followed by
    ``target_frames`` latent frames, concatenated on the time axis, plus a
    per-target-frame camera pose (relative to the last condition frame)
    flattened to 12 values.
    """

    def __init__(self, base_path, metadata_path, steps_per_epoch, condition_frames=32, target_frames=32):
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        print(len(self.path), "videos in metadata.")
        # Keep only videos whose latents were pre-computed by the data_process task.
        self.path = [i + ".recam.pth" for i in self.path if os.path.exists(i + ".recam.pth")]
        print(len(self.path), "tensors cached in metadata.")
        assert len(self.path) > 0
        self.steps_per_epoch = steps_per_epoch
        self.condition_frames = int(condition_frames)
        self.target_frames = int(target_frames)

    def parse_matrix(self, matrix_str):
        """Parse a matrix serialized as ``'[a b ...] [c d ...] ...'`` into an ndarray."""
        rows = matrix_str.strip().split('] [')
        matrix = []
        for row in rows:
            row = row.replace('[', '').replace(']', '')
            matrix.append(list(map(float, row.split())))
        return np.array(matrix)

    def get_relative_pose(self, pose_prev, pose_curr):
        """Relative pose taking pose_prev to pose_curr: ``pose_curr @ inv(pose_prev)``."""
        return pose_curr @ np.linalg.inv(pose_prev)

    def __getitem__(self, index):
        while True:
            try:
                data = {}
                data_id = torch.randint(0, len(self.path), (1,))[0]
                data_id = (data_id + index) % len(self.path)  # deterministic under a fixed seed

                path = self.path[data_id]
                video_data = torch.load(path, weights_only=True, map_location="cpu")

                full_latents = video_data['latents']  # [C, T, H, W]
                total_frames = full_latents.shape[1]

                # Need enough frames for one condition segment plus one target segment.
                required_frames = self.condition_frames + self.target_frames
                if total_frames < required_frames:
                    continue  # too short — re-roll another video

                max_start = total_frames - required_frames
                start_frame = random.randint(0, max_start) if max_start > 0 else 0

                # Condition frames first, target frames second (training convention).
                cond_end = start_frame + self.condition_frames
                condition_latents = full_latents[:, start_frame:cond_end, :, :]
                target_latents = full_latents[:, cond_end:cond_end + self.target_frames, :, :]
                data['latents'] = torch.cat([condition_latents, target_latents], dim=1)
                data['prompt_emb'] = video_data['prompt_emb']
                data['image_emb'] = video_data.get('image_emb', {})

                # Camera poses: one 3x4 pose per target frame, expressed relative
                # to the last condition frame, flattened to 12 values.
                base_path = path.rsplit('/', 2)[0]
                camera_path = os.path.join(base_path, "cameras", "camera_extrinsics.json")
                if not os.path.exists(camera_path):
                    # No camera data for this video: zero pose conditioning.
                    pose_embedding = torch.zeros(self.target_frames, 12, dtype=torch.bfloat16)
                else:
                    with open(camera_path, 'r') as file:
                        cam_data = json.load(file)

                    match = re.search(r'cam(\d+)', path)
                    cam_idx = int(match.group(1)) if match else 1  # NOTE(review): silent fallback to cam 1
                    cam_key = f"cam{cam_idx:02d}"

                    def load_pose(frame_idx):
                        # Return the 4x4 pose for frame_idx, or None when absent.
                        frame_key = f"frame{frame_idx}"
                        if frame_key not in cam_data or cam_key not in cam_data[frame_key]:
                            return None
                        pose = self.parse_matrix(cam_data[frame_key][cam_key])
                        if pose.shape == (3, 4):
                            pose = np.vstack([pose, np.array([0, 0, 0, 1.0])])
                        return pose

                    # Reference = last frame of the condition segment.
                    reference_pose = load_pose(cond_end - 1)
                    if reference_pose is None:
                        reference_pose = np.eye(4, dtype=np.float32)

                    relative_poses = []
                    for i in range(self.target_frames):
                        target_pose = load_pose(cond_end + i)
                        if target_pose is None:
                            # Missing frame: fall back to the identity pose.
                            relative_poses.append(torch.as_tensor(np.eye(3, 4, dtype=np.float32)))
                        else:
                            relative_pose = self.get_relative_pose(reference_pose, target_pose)
                            # FIX: cast to float32 — parse_matrix yields float64 while the
                            # identity fallback is float32, and torch.stack requires a
                            # single dtype across all stacked tensors.
                            relative_poses.append(torch.as_tensor(relative_pose[:3, :].astype(np.float32)))

                    pose_embedding = torch.stack(relative_poses, dim=0)       # [target_frames, 3, 4]
                    pose_embedding = pose_embedding.reshape(self.target_frames, 12)

                data['camera'] = pose_embedding.to(torch.bfloat16)
                break

            except Exception as e:
                print(f"ERROR WHEN LOADING: {e}")
                index = random.randrange(len(self.path))

        return data

    def __len__(self):
        return self.steps_per_epoch
{name}") + for param in module.parameters(): + param.requires_grad = True + self.condition_frames = int(condition_frames) + self.target_frames = int(target_frames) + trainable_params = 0 + seen_params = set() + for name, module in self.pipe.denoising_model().named_modules(): + for param in module.parameters(): + if param.requires_grad and param not in seen_params: + trainable_params += param.numel() + seen_params.add(param) + print(f"Total number of trainable parameters: {trainable_params}") + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + + def freeze_parameters(self): + # Freeze parameters + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + + def training_step(self, batch, batch_idx): + # Data + latents = batch["latents"].to(self.device) # [B, C, T, H, W], T = condition_frames + target_frames + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + + # target_height, target_width = 40, 70 + # current_height, current_width = latents.shape[3], latents.shape[4] + + # if current_height > target_height or current_width > target_width: + # h_start = (current_height - target_height) // 2 + # w_start = (current_width - target_width) // 2 + # latents = latents[:, :, :, + # h_start:h_start+target_height, + # w_start:w_start+target_width] + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + cam_emb = batch["camera"].to(self.device) # [B, target_frames, 12] - 只有target帧的pose + + # Loss + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = 
self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + extra_input = self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + # 🔧 修复:condition段在前,保持clean;target段在后,参与去噪训练 + cond_len = self.condition_frames + noisy_latents[:, :, :cond_len, ...] = origin_latents[:, :, :cond_len, ...] + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # Compute loss (只对target段计算loss) + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + + # 🔧 修复:只对target段(后半部分)计算loss + target_noise_pred = noise_pred[:, :, cond_len:, ...] + target_training_target = training_target[:, :, cond_len:, ...] + + loss = torch.nn.functional.mse_loss( + target_noise_pred.float(), + target_training_target.float() + ) + loss = loss * self.pipe.scheduler.training_weight(timestep) + + wandb.log({ + "train_loss": loss.item(), + "condition_frames": cond_len, + "target_frames": self.target_frames, + }) + return loss + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + + def on_save_checkpoint(self, checkpoint): + checkpoint_dir = "/share_zhuyixuan05/zhuyixuan05/recam_future" + print(f"Checkpoint directory: {checkpoint_dir}") + current_step = self.global_step + print(f"Current step: {current_step}") + + checkpoint.clear() + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters())) + trainable_param_names = set([named_param[0] for named_param in 
trainable_param_names]) + state_dict = self.pipe.denoising_model().state_dict() + torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}.ckpt")) + + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train ReCamMaster") + parser.add_argument( + "--task", + type=str, + default="train", + choices=["data_process", "train"], + help="Task. `data_process` or `train`.", + ) + parser.add_argument( + "--dataset_path", + type=str, + default="/share_zhuyixuan05/zhuyixuan05/MultiCamVideo-Dataset", + help="The path of the Dataset.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./", + help="Path to save the model.", + ) + parser.add_argument( + "--text_encoder_path", + type=str, + default=None, + help="Path of text encoder.", + ) + parser.add_argument( + "--image_encoder_path", + type=str, + default=None, + help="Path of image encoder.", + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help="Path of VAE.", + ) + parser.add_argument( + "--dit_path", + type=str, + default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + help="Path of DiT.", + ) + parser.add_argument( + "--tiled", + default=False, + action="store_true", + help="Whether enable tile encode in VAE. 
def parse_args():
    """Build the CLI for the condition/target ReCamMaster training script."""
    parser = argparse.ArgumentParser(description="Train ReCamMaster")
    parser.add_argument(
        "--task",
        type=str,
        default="train",
        choices=["data_process", "train"],
        help="Task. `data_process` or `train`.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/MultiCamVideo-Dataset",
        help="The path of the Dataset.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./",
        help="Path to save the model.",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default=None,
        help="Path of text encoder.",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        help="Path of image encoder.",
    )
    parser.add_argument(
        "--vae_path",
        type=str,
        default=None,
        help="Path of VAE.",
    )
    parser.add_argument(
        "--dit_path",
        type=str,
        default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        help="Path of DiT.",
    )
    parser.add_argument(
        "--tiled",
        default=False,
        action="store_true",
        help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
    )
    parser.add_argument(
        "--tile_size_height",
        type=int,
        default=34,
        help="Tile size (height) in VAE.",
    )
    parser.add_argument(
        "--tile_size_width",
        type=int,
        default=34,
        help="Tile size (width) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_height",
        type=int,
        default=18,
        help="Tile stride (height) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_width",
        type=int,
        default=16,
        help="Tile stride (width) in VAE.",
    )
    parser.add_argument(
        "--steps_per_epoch",
        type=int,
        default=1000,
        help="Number of steps per epoch.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=81,
        help="Number of frames.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Image height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=832,
        help="Image width.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=4,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="Learning rate.",
    )
    parser.add_argument(
        "--accumulate_grad_batches",
        type=int,
        default=1,
        help="The number of batches in gradient accumulation.",
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=2,
        help="Number of epochs.",
    )
    parser.add_argument(
        "--training_strategy",
        type=str,
        default="deepspeed_stage_1",
        choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
        help="Training strategy",
    )
    parser.add_argument(
        "--use_gradient_checkpointing",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument(
        "--use_gradient_checkpointing_offload",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing offload.",
    )
    # FIX: was `default=True, action="store_true"` — the flag could never be
    # disabled. BooleanOptionalAction keeps the default (True) and `--use_swanlab`
    # still works, while adding `--no-use_swanlab` to turn logging off.
    parser.add_argument(
        "--use_swanlab",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Whether to use SwanLab logger.",
    )
    parser.add_argument(
        "--swanlab_mode",
        default="cloud",
        help="SwanLab mode (cloud or local).",
    )
    parser.add_argument(
        "--metadata_file_name",
        type=str,
        default="metadata.csv",
    )
    parser.add_argument(
        "--resume_ckpt_path",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--condition_frames",
        type=int,
        default=10,
        help="Number of condition frames (kept clean).",
    )
    parser.add_argument(
        "--target_frames",
        type=int,
        default=10,
        help="Number of target frames (to be denoised).",
    )
    args = parser.parse_args()
    return args
shuffle=False, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForDataProcess( + text_encoder_path=args.text_encoder_path, + image_encoder_path=args.image_encoder_path, + vae_path=args.vae_path, + tiled=args.tiled, + tile_size=(args.tile_size_height, args.tile_size_width), + tile_stride=(args.tile_stride_height, args.tile_stride_width), + ) + trainer = pl.Trainer( + accelerator="gpu", + devices="auto", + default_root_dir=args.output_path, + ) + trainer.test(model, dataloader) + + +def train(args): + dataset = TensorDataset( + args.dataset_path, + os.path.join(args.dataset_path, "metadata.csv"), + steps_per_epoch=args.steps_per_epoch, + condition_frames=args.condition_frames, + target_frames=args.target_frames, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + condition_frames=args.condition_frames, + target_frames=args.target_frames, + ) + + if args.use_swanlab: + wandb.init( + project="recam", + name="recam", + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[], + ) + trainer.fit(model, dataloader) + + +if __name__ == '__main__': + args = parse_args() + os.makedirs(os.path.join(args.output_path, "checkpoints"), exist_ok=True) + if args.task == "data_process": + data_process(args) + elif args.task == "train": + train(args) \ No newline at end of file diff --git a/scripts/train_recammaster.py b/scripts/train_recammaster.py new file mode 100644 index 
class TextVideoDataset(torch.utils.data.Dataset):
    """Loads (text, video) pairs listed in a metadata CSV and decodes, crops and
    normalizes frames to [C, T, H, W] tensors."""

    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.is_i2v = is_i2v

        self.frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Upscale so that a (self.height, self.width) center crop fully fits."""
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Decode num_frames frames (stride=interval) as a [C, T, H, W] tensor.

        Returns None when the video is too short; with is_i2v also returns the
        first frame as a raw ndarray.
        """
        reader = imageio.get_reader(file_path)
        total = reader.count_frames()  # hoisted: avoid probing the container twice
        if total < max_num_frames or total - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None

        frames = []
        first_frame = None
        for frame_id in range(num_frames):
            frame = reader.get_data(start_frame_id + frame_id * interval)
            frame = Image.fromarray(frame)
            frame = self.crop_and_resize(frame)
            if first_frame is None:
                first_frame = np.array(frame)
            frame = frame_process(frame)
            frames.append(frame)
        reader.close()

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        if self.is_i2v:
            return frames, first_frame
        return frames

    def load_video(self, file_path):
        start_frame_id = 0
        return self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)

    def is_image(self, file_path):
        """True if file_path has a known image extension (case-insensitive)."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]

    def load_image(self, file_path):
        frame = Image.open(file_path).convert("RGB")
        frame = self.crop_and_resize(frame)
        frame = self.frame_process(frame)
        frame = rearrange(frame, "C H W -> C 1 H W")
        return frame

    def __getitem__(self, data_id):
        # FIX: the original fetched text/path ONCE before the retry loop, so a
        # broken file retried itself forever (data_id was bumped but never
        # re-read). Re-fetch inside the loop and wrap around at the end.
        while True:
            text = self.text[data_id]
            path = self.path[data_id]
            try:
                if self.is_image(path):
                    if self.is_i2v:
                        raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
                    video = self.load_image(path)
                else:
                    video = self.load_video(path)
                if self.is_i2v:
                    video, first_frame = video
                    data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
                else:
                    data = {"text": text, "video": video, "path": path}
                break
            except Exception:
                # FIX: narrowed from a bare `except:` (which also swallowed
                # KeyboardInterrupt); advance with wraparound.
                data_id = (data_id + 1) % len(self.path)
        return data

    def __len__(self):
        return len(self.path)
I2V model doesn't support image-to-image training.") + video = self.load_image(path) + else: + video = self.load_video(path) + if self.is_i2v: + video, first_frame = video + data = {"text": text, "video": video, "path": path, "first_frame": first_frame} + else: + data = {"text": text, "video": video, "path": path} + break + except: + data_id += 1 + return data + + + def __len__(self): + return len(self.path) + + + +class LightningModelForDataProcess(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_path = [text_encoder_path, vae_path] + if image_encoder_path is not None: + model_path.append(image_encoder_path) + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models(model_path) + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + def test_step(self, batch, batch_idx): + text, video, path = batch["text"][0], batch["video"], batch["path"][0] + + self.pipe.device = self.device + if video is not None: + pth_path = path + ".recam.pth" + if not os.path.exists(pth_path): + # prompt + prompt_emb = self.pipe.encode_prompt(text) + # video + video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] + # image + if "first_frame" in batch: + first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy()) + _, _, num_frames, height, width = video.shape + image_emb = self.pipe.encode_image(first_frame, num_frames, height, width) + else: + image_emb = {} + data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb} + torch.save(data, pth_path) + print(f"Output: {pth_path}") + else: + print(f"File {pth_path} already exists, skipping.") + +class Camera(object): + def __init__(self, c2w): + 
class Camera(object):
    """Wraps a flattened 4x4 camera-to-world matrix and caches its inverse."""

    def __init__(self, c2w):
        mat = np.array(c2w).reshape(4, 4)
        self.c2w_mat = mat
        self.w2c_mat = np.linalg.inv(mat)


class TensorDataset(torch.utils.data.Dataset):
    """Pairs a target-camera latent with a random other camera's latent from
    the same scene, and builds the target trajectory relative to the condition
    camera's first pose."""

    def __init__(self, base_path, metadata_path, steps_per_epoch):
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "train", name) for name in metadata["file_name"]]
        print(len(self.path), "videos in metadata.")
        self.path = [p + ".recam.pth" for p in self.path if os.path.exists(p + ".recam.pth")]
        print(len(self.path), "tensors cached in metadata.")
        assert len(self.path) > 0
        self.steps_per_epoch = steps_per_epoch

    def parse_matrix(self, matrix_str):
        """Parse a matrix serialized as '[a b ...] [c d ...] ...' into an ndarray."""
        rows = matrix_str.strip().split('] [')
        cleaned = [row.replace('[', '').replace(']', '') for row in rows]
        return np.array([list(map(float, row.split())) for row in cleaned])

    def get_relative_pose(self, cam_params):
        """Express every pose relative to the first camera (which becomes identity)."""
        abs_w2cs = [cam.w2c_mat for cam in cam_params]
        abs_c2ws = [cam.c2w_mat for cam in cam_params]
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w] + [abs2rel @ c2w for c2w in abs_c2ws[1:]]
        return np.array(ret_poses, dtype=np.float32)

    def __getitem__(self, index):
        # Returns:
        #   data['latents']    : [16, 21*2, 60, 104] (target then condition latents)
        #   data['camera']     : [21, 12] flattened 3x4 relative poses (bfloat16)
        #   data['prompt_emb'] : the target video's text embedding
        while True:
            try:
                data = {}
                data_id = torch.randint(0, len(self.path), (1,))[0]
                data_id = (data_id + index) % len(self.path)  # deterministic under a fixed seed
                path_tgt = self.path[data_id]
                data_tgt = torch.load(path_tgt, weights_only=True, map_location="cpu")

                # Pick a different camera of the same scene as the condition view.
                tgt_idx = int(re.search(r'cam(\d+)', path_tgt).group(1))
                cond_idx = random.randint(1, 10)
                while cond_idx == tgt_idx:
                    cond_idx = random.randint(1, 10)
                path_cond = re.sub(r'cam(\d+)', f'cam{cond_idx:02}', path_tgt)
                data_cond = torch.load(path_cond, weights_only=True, map_location="cpu")

                data['latents'] = torch.cat((data_tgt['latents'], data_cond['latents']), dim=1)
                data['prompt_emb'] = data_tgt['prompt_emb']
                data['image_emb'] = {}

                # Extrinsics for every 4th frame of both cameras.
                base_path = path_tgt.rsplit('/', 2)[0]
                with open(os.path.join(base_path, "cameras", "camera_extrinsics.json"), 'r') as file:
                    cam_data = json.load(file)

                frame_ids = list(range(81))[::4]
                multiview_c2ws = []
                for view_idx in (cond_idx, tgt_idx):
                    traj = np.stack([
                        self.parse_matrix(cam_data[f"frame{idx}"][f"cam{view_idx:02d}"])
                        for idx in frame_ids
                    ]).transpose(0, 2, 1)
                    c2ws = []
                    for c2w in traj:
                        c2w = c2w[:, [1, 2, 0, 3]]  # reorder axes to the pipeline's convention
                        c2w[:3, 1] *= -1.           # flip the second axis
                        c2w[:3, 3] /= 100           # scale translation by 1/100 (presumably cm -> m)
                        c2ws.append(c2w)
                    multiview_c2ws.append(c2ws)

                cond_cam_params = [Camera(m) for m in multiview_c2ws[0]]
                tgt_cam_params = [Camera(m) for m in multiview_c2ws[1]]

                # Each target pose is expressed relative to the first condition pose.
                relative_poses = []
                for tgt_cam in tgt_cam_params:
                    pair = self.get_relative_pose([cond_cam_params[0], tgt_cam])
                    relative_poses.append(torch.as_tensor(pair)[:, :3, :][1])
                pose_embedding = torch.stack(relative_poses, dim=0)  # 21x3x4
                pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
                data['camera'] = pose_embedding.to(torch.bfloat16)
                break
            except Exception as e:
                print(f"ERROR WHEN LOADING: {e}")
                index = random.randrange(len(self.path))
        return data

    def __len__(self):
        return self.steps_per_epoch
class LightningModelForTrain(pl.LightningModule):
    """ReCamMaster fine-tuning: the first (target-camera) half of the latent
    sequence is denoised while the second (condition-camera) half stays clean."""

    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
        resume_ckpt_path=None
    ):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        if os.path.isfile(dit_path):
            model_manager.load_models([dit_path])
        else:
            # Comma-separated list of checkpoint shards.
            dit_path = dit_path.split(",")
            model_manager.load_models([dit_path])

        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)

        # Per-block camera adapters: cam_encoder starts at zero and projector at
        # identity, so step 0 reproduces the base model exactly.
        dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
        for block in self.pipe.dit.blocks:
            block.cam_encoder = nn.Linear(12, dim)
            block.projector = nn.Linear(dim, dim)
            block.cam_encoder.weight.data.zero_()
            block.cam_encoder.bias.data.zero_()
            block.projector.weight = nn.Parameter(torch.eye(dim))
            block.projector.bias = nn.Parameter(torch.zeros(dim))

        if resume_ckpt_path is not None:
            state_dict = torch.load(resume_ckpt_path, map_location="cpu")
            self.pipe.dit.load_state_dict(state_dict, strict=True)

        self.freeze_parameters()
        # Only the adapters and self-attention modules take gradients.
        for name, module in self.pipe.denoising_model().named_modules():
            if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn"]):
                print(f"Trainable: {name}")
                for param in module.parameters():
                    param.requires_grad = True

        trainable_params = 0
        seen_params = set()
        for name, module in self.pipe.denoising_model().named_modules():
            for param in module.parameters():
                if param.requires_grad and param not in seen_params:
                    trainable_params += param.numel()
                    seen_params.add(param)
        print(f"Total number of trainable parameters: {trainable_params}")

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

    def freeze_parameters(self):
        """Freeze the whole pipeline; the denoising model stays in train mode so
        the re-enabled submodules can learn."""
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()

    def training_step(self, batch, batch_idx):
        # latents: [B, C, 2*T, H, W] — target-camera frames first, condition second.
        latents = batch["latents"].to(self.device)
        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]

        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        cam_emb = batch["camera"].to(self.device)

        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)

        extra_input = self.pipe.prepare_extra_input(latents)
        # FIX: clone() is sufficient for a tensor snapshot; copy.deepcopy is wasteful.
        origin_latents = latents.clone()
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        # Second half (condition camera) stays clean; the first half is denoised.
        tgt_latent_len = noisy_latents.shape[2] // 2
        noisy_latents[:, :, tgt_latent_len:, ...] = origin_latents[:, :, tgt_latent_len:, ...]
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        noise_pred = self.pipe.denoising_model()(
            noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )
        # Loss only over the target (noised) half.
        loss = torch.nn.functional.mse_loss(
            noise_pred[:, :, :tgt_latent_len, ...].float(),
            training_target[:, :, :tgt_latent_len, ...].float()
        )
        loss = loss * self.pipe.scheduler.training_weight(timestep)

        # FIX: wandb.init() only runs when --use_swanlab is set; guard the log call
        # so training without a run does not crash.
        if wandb.run is not None:
            wandb.log({"train_loss": loss.item()})

        return loss

    def configure_optimizers(self):
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
        return torch.optim.AdamW(trainable_modules, lr=self.learning_rate)

    def on_save_checkpoint(self, checkpoint):
        # TODO(review): checkpoint_dir is hard-coded; wire it to args.output_path.
        checkpoint_dir = "/home/zhuyixuan05/ReCamMaster/nus_models/checkpoints"
        print(f"Checkpoint directory: {checkpoint_dir}")
        current_step = self.global_step
        print(f"Current step: {current_step}")

        # Drop Lightning's own payload and write the full state_dict manually.
        # (Resume loads it with strict=True, so saving only trainable params
        # would break resuming — the dead trainable-name computation that used
        # to live here was removed.)
        checkpoint.clear()
        state_dict = self.pipe.denoising_model().state_dict()
        torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}.ckpt"))
def parse_args():
    """Build the CLI for the original ReCamMaster training script."""
    parser = argparse.ArgumentParser(description="Train ReCamMaster")
    parser.add_argument(
        "--task",
        type=str,
        default="train",
        choices=["data_process", "train"],
        help="Task. `data_process` or `train`.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/MultiCamVideo-Dataset",
        help="The path of the Dataset.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./",
        help="Path to save the model.",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default=None,
        help="Path of text encoder.",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        help="Path of image encoder.",
    )
    parser.add_argument(
        "--vae_path",
        type=str,
        default=None,
        help="Path of VAE.",
    )
    parser.add_argument(
        "--dit_path",
        type=str,
        default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        help="Path of DiT.",
    )
    parser.add_argument(
        "--tiled",
        default=False,
        action="store_true",
        help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
    )
    parser.add_argument(
        "--tile_size_height",
        type=int,
        default=34,
        help="Tile size (height) in VAE.",
    )
    parser.add_argument(
        "--tile_size_width",
        type=int,
        default=34,
        help="Tile size (width) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_height",
        type=int,
        default=18,
        help="Tile stride (height) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_width",
        type=int,
        default=16,
        help="Tile stride (width) in VAE.",
    )
    parser.add_argument(
        "--steps_per_epoch",
        type=int,
        default=400,
        help="Number of steps per epoch.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=81,
        help="Number of frames.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Image height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=832,
        help="Image width.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=4,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="Learning rate.",
    )
    parser.add_argument(
        "--accumulate_grad_batches",
        type=int,
        default=1,
        help="The number of batches in gradient accumulation.",
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=3,
        help="Number of epochs.",
    )
    parser.add_argument(
        "--training_strategy",
        type=str,
        default="deepspeed_stage_1",
        choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
        help="Training strategy",
    )
    parser.add_argument(
        "--use_gradient_checkpointing",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument(
        "--use_gradient_checkpointing_offload",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing offload.",
    )
    # FIX: was `default=True, action="store_true"` — the flag could never be
    # disabled. BooleanOptionalAction keeps the default (True) and `--use_swanlab`
    # still works, while adding `--no-use_swanlab` to turn logging off.
    parser.add_argument(
        "--use_swanlab",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Whether to use SwanLab logger.",
    )
    parser.add_argument(
        "--swanlab_mode",
        default="cloud",
        help="SwanLab mode (cloud or local).",
    )
    parser.add_argument(
        "--metadata_file_name",
        type=str,
        default="metadata.csv",
    )
    parser.add_argument(
        "--resume_ckpt_path",
        type=str,
        default=None,
    )
    args = parser.parse_args()
    return args
+ tile_size=(args.tile_size_height, args.tile_size_width), + tile_stride=(args.tile_stride_height, args.tile_stride_width), + ) + trainer = pl.Trainer( + accelerator="gpu", + devices="auto", + default_root_dir=args.output_path, + ) + trainer.test(model, dataloader) + + +def train(args): + dataset = TensorDataset( + args.dataset_path, + os.path.join(args.dataset_path, "metadata.csv"), + steps_per_epoch=args.steps_per_epoch, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=2, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + ) + + if args.use_swanlab: + wandb.init( + project="recam", + name="recam", + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[], + ) + trainer.fit(model, dataloader) + + +if __name__ == '__main__': + args = parse_args() + os.makedirs(os.path.join(args.output_path, "checkpoints"), exist_ok=True) + if args.task == "data_process": + data_process(args) + elif args.task == "train": + train(args) \ No newline at end of file diff --git a/scripts/train_rlbench_noise.py b/scripts/train_rlbench_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..0b8f90c7bffcbfbf988c6706643910c8acebef4a --- /dev/null +++ b/scripts/train_rlbench_noise.py @@ -0,0 +1,596 @@ +import torch +import torch.nn as nn +import lightning as pl +import wandb +import os +import copy +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import os +import json +import torch +import numpy as np +from PIL import Image +import imageio 
import random
from torchvision.transforms import v2
from einops import rearrange
from pose_classifier import PoseClassifier


def get_traj_position_change(cam_c2w, stride=1):
    """Per-step displacements and turning angles along a camera trajectory.

    Args:
        cam_c2w: [N, 4, 4] camera-to-world matrices.
        stride: frame stride between compared steps.

    Returns:
        A pair ``(displacements, turn_angles)``: each displacement is a
        ``[dx, dy, dz]`` vector, each turn angle is in degrees within (0, 180).
    """
    xyz = cam_c2w[:, :3, 3]

    displacements = []
    turn_angles = []
    for idx in range(len(xyz) - 2 * stride):
        step_a = xyz[idx + stride] - xyz[idx]
        step_b = xyz[idx + 2 * stride] - xyz[idx + stride]

        len_a = np.linalg.norm(step_a)
        len_b = np.linalg.norm(step_b)
        # A near-zero step has no reliable direction; skip both outputs for it.
        if len_a < 1e-6 or len_b < 1e-6:
            continue

        cos_val = np.dot(step_a, step_b) / (len_a * len_b)
        turn = np.degrees(np.arccos(np.clip(cos_val, -1.0, 1.0)))

        displacements.append(step_a)
        turn_angles.append(turn)

    return displacements, turn_angles


def get_traj_rotation_change(cam_c2w, stride=1):
    """Angles (degrees) between camera z-axes of frames ``stride`` apart."""
    rot = cam_c2w[:, :3, :3]

    angles = []
    for idx in range(len(rot) - stride):
        axis_a = rot[idx][:, 2]
        axis_b = rot[idx + stride][:, 2]

        len_a = np.linalg.norm(axis_a)
        len_b = np.linalg.norm(axis_b)
        if len_a < 1e-6 or len_b < 1e-6:
            continue

        cos_val = np.dot(axis_a, axis_b) / (len_a * len_b)
        angles.append(np.degrees(np.arccos(np.clip(cos_val, -1.0, 1.0))))

    # Each entry is a rotation-angle change in (0, 180) degrees.
    return angles


def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Relative pose of camera B with respect to camera A.

    Args:
        pose_a: 4x4 extrinsic of camera A (world -> camera A), numpy or torch.
        pose_b: 4x4 extrinsic of camera B (world -> camera B), numpy or torch.
        use_torch: compute with PyTorch instead of NumPy.

    Returns:
        The 4x4 transform mapping camera-A coordinates to camera-B
        coordinates, i.e. ``pose_b @ inv(pose_a)``.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))


class DynamicRLBenchDataset(torch.utils.data.Dataset):
    """RLBench scenes served as segments with a dynamic history length.

    Every usable scene directory under ``base_path`` must contain a
    pre-encoded ``encoded_video.pth``.
    """

    def __init__(self, base_path, steps_per_epoch,
                 min_condition_frames=10, max_condition_frames=40,
                 target_frames=10, height=900, width=1600):
        self.base_path = base_path
        self.scenes_path = base_path
        self.min_condition_frames = min_condition_frames
        self.max_condition_frames = max_condition_frames
        self.target_frames = target_frames
        self.height = height
        self.width = width
        self.steps_per_epoch = steps_per_epoch
        self.pose_classifier = PoseClassifier()

        # The VAE compresses the temporal axis by this factor.
        self.time_compression_ratio = 4

        # Collect every scene directory that already has encoded latents.
        self.scene_dirs = []
        if os.path.exists(self.scenes_path):
            for entry in os.listdir(self.scenes_path):
                candidate = os.path.join(self.scenes_path, entry)
                if not os.path.isdir(candidate):
                    continue
                if os.path.exists(os.path.join(candidate, "encoded_video.pth")):
                    self.scene_dirs.append(candidate)
        assert len(self.scene_dirs) > 0, "No encoded scenes found!"
+ + # 预处理设置 + # self.frame_process = v2.Compose([ + # v2.CenterCrop(size=(height, width)), + # v2.Resize(size=(height, width), antialias=True), + # v2.ToTensor(), + # v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + # ]) + + def calculate_relative_rotation(self, current_rotation, reference_rotation): + """计算相对旋转四元数""" + q_current = torch.tensor(current_rotation, dtype=torch.float32) + q_ref = torch.tensor(reference_rotation, dtype=torch.float32) + + q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]]) + + w1, x1, y1, z1 = q_ref_inv + w2, x2, y2, z2 = q_current + + relative_rotation = torch.tensor([ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + ]) + + return relative_rotation + + def select_dynamic_segment(self, full_latents): + """动态选择条件帧和目标帧 - 修正版本处理VAE时间压缩""" + total_lens = full_latents.shape[1] + # print(f"原始总帧数: {total_frames}, 压缩后: {compressed_total_frames}") + # print(f"原始关键帧: {keyframe_indices[:5]}..., 压缩后: {compressed_keyframe_indices[:5]}...") + + # 随机选择条件帧长度(基于压缩后的帧数) + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + + ratio = random.random() + print('ratio:',ratio) + if ratio<0.15: + condition_frames_compressed = 1 + elif 0.15<=ratio<0.3: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if total_lens < min_required_frames: + print(f"压缩后帧数不足: {total_lens} < {min_required_frames}") + return None + + # 随机选择起始位置(基于压缩后的帧数) + max_start = total_lens - min_required_frames - 1 + start_frame_compressed = 
random.randint(0, max_start) + + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + + + # 使用条件段的最后一个关键帧作为reference + reference_keyframe_compressed = start_frame_compressed + + # 🔧 找到对应的原始关键帧索引用于pose查找 + keyframe_original_idx = [] + for compressed_idx in range(start_frame_compressed,target_end_compressed): + keyframe_original_idx.append(compressed_idx*4) + + + + return { + 'start_frame': start_frame_compressed, # 压缩后的起始帧 + 'condition_frames': condition_frames_compressed, # 压缩后的条件帧数 + 'target_frames': target_frames_compressed, # 压缩后的目标帧数 + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + 'keyframe_original_idx': keyframe_original_idx, # 原始关键帧索引 + + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, # 用于记录 + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + + + def create_pose_embeddings(self, cam_data, segment_info): + """创建pose embeddings - 修正版本,确保与latent帧数对齐""" + cam_data_seq = cam_data # 300 * 4 * 4 + keyframe_original_idx = segment_info['keyframe_original_idx'] + # target_keyframe_indices = segment_info['target_keyframe_indices'] + + start_frame = segment_info['start_frame'] * self.time_compression_ratio + end_frame = segment_info['target_range'][1] * self.time_compression_ratio + # frame_range = cam_data_seq[start_frame:end_frame] + + relative_cams = [] + for idx in keyframe_original_idx: + cam_prev = cam_data_seq[idx] + # cam_next = cam_data_seq[idx+4] + # print('cam_prev:',cam_prev) + # print('idx:',idx) + # assert False + #relative_cam = compute_relative_pose(cam_prev,cam_next) + # print('relative_cam:',relative_cam) + # assert False + relative_cams.append(torch.as_tensor(cam_prev)) + + pose_embedding = torch.stack(relative_cams, dim=0) + # pose_embedding = rearrange(pose_embedding, 'b 
c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + # print(pose_embedding.shape) + # assert False + # print() + # traj_pos_coord_full, tarj_pos_angle_full = get_traj_position_change(cam_data_seq, self.time_compression_ratio) + # traj_rot_angle_full = get_traj_rotation_change(cam_data_seq, self.time_compression_ratio) + + # motion_emb = + + return { + 'camera': pose_embedding + } + + def __getitem__(self, index): + while True: + try: + # 随机选择一个场景 + scene_dir = random.choice(self.scene_dirs) + + # 加载场景信息 + # with open(os.path.join(scene_dir, "scene_info.json"), 'r') as f: + # scene_info = json.load(f) + + # 加载编码的视频数据 + encoded_data = torch.load( + os.path.join(scene_dir, "encoded_video.pth"), + weights_only=False, + map_location="cpu" + ) + + # 🔧 验证latent帧数是否符合预期 + full_latents = encoded_data['latents'] # [C, T, H, W] + cam_data = encoded_data['cam_emb'] + # expected_latent_frames = scene_info['total_frames'] // self.time_compression_ratio + actual_latent_frames = full_latents.shape[1] + + # print(f"场景 {os.path.basename(scene_dir)}: 原始帧数={scene_info['total_frames']}, " + # f"预期latent帧数={expected_latent_frames}, 实际latent帧数={actual_latent_frames}") + + # if abs(actual_latent_frames - expected_latent_frames) > 2: # 允许小的舍入误差 + # print(f"⚠️ Latent帧数不匹配,跳过此样本") + # continue + + # 动态选择段落 + segment_info = self.select_dynamic_segment(full_latents) + if segment_info is None: + continue + # print("segment_info:",segment_info) + # 创建pose embeddings + pose_data = self.create_pose_embeddings(cam_data, segment_info) + if pose_data is None: + continue + + n = segment_info["condition_frames"] + m = segment_info['target_frames'] + + + mask = torch.zeros(n+m, dtype=torch.float32) + mask[:n] = 1.0 + mask = mask.view(-1, 1) + + + pose_data["camera"] = torch.cat([pose_data["camera"], mask], dim=1) + # print(pose_data['camera'].shape) + # assert False + # 🔧 使用压缩后的索引提取latent段落 + start_frame = segment_info['start_frame'] # 已经是压缩后的索引 + condition_frames = 
segment_info['condition_frames'] # 已经是压缩后的帧数 + target_frames = segment_info['target_frames'] # 已经是压缩后的帧数 + + # print(f"提取latent段落: start={start_frame}, condition={condition_frames}, target={target_frames}") + # print(f"Full latents shape: {full_latents.shape}") + + # # 确保索引不越界 + # if start_frame + condition_frames + target_frames > full_latents.shape[1]: + # print(f"索引越界,跳过: {start_frame + condition_frames + target_frames} > {full_latents.shape[1]}") + # continue + + condition_latents = full_latents[:, start_frame:start_frame+condition_frames, :, :] + + + + target_latents = full_latents[:, start_frame+condition_frames:start_frame+condition_frames+target_frames, :, :] + + # print(f"Condition latents shape: {condition_latents.shape}") + # print(f"Target latents shape: {target_latents.shape}") + + # 拼接latents [condition, target] + combined_latents = torch.cat([condition_latents, target_latents], dim=1) + + result = { + "latents": combined_latents, + "prompt_emb": encoded_data["prompt_emb"], + "image_emb": encoded_data.get("image_emb", {}), + "camera": pose_data['camera'], + + "condition_frames": condition_frames, # 压缩后的帧数 + "target_frames": target_frames, # 压缩后的帧数 + "scene_name": os.path.basename(scene_dir), + # 🔧 新增:记录原始帧数用于调试 + "original_condition_frames": segment_info['original_condition_frames'], + "original_target_frames": segment_info['original_target_frames'], + } + + return result + + except Exception as e: + print(f"Error loading sample: {e}") + import traceback + traceback.print_exc() + continue + + def __len__(self): + return self.steps_per_epoch + +class DynamicLightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + resume_ckpt_path=None + ): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = 
dit_path.split(",") + model_manager.load_models([dit_path]) + model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + # 添加相机编码器 + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + block.cam_encoder = nn.Linear(30 , dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + self.pipe.dit.load_state_dict(state_dict, strict=True) + print('load checkpoint:', resume_ckpt_path) + + self.freeze_parameters() + + # 只训练相机相关和注意力模块 + for name, module in self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn"]): + for param in module.parameters(): + param.requires_grad = True + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + # 创建可视化目录 + self.vis_dir = "sekai_dynamic/visualizations_dynamic" + os.makedirs(self.vis_dir, exist_ok=True) + + def freeze_parameters(self): + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + def training_step(self, batch, batch_idx): + # 获取动态长度信息(这些已经是压缩后的帧数) + condition_frames = batch["condition_frames"][0].item() # 压缩后的condition长度 + target_frames = batch["target_frames"][0].item() # 压缩后的target长度 + + # 🔧 获取原始帧数用于日志记录 + original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0] + print("condition_frames:",batch["condition_frames"]) + original_target_frames = batch.get("original_target_frames", 
[target_frames * 4])[0] + + # Data + latents = batch["latents"].to(self.device) + # print(f"压缩后condition帧数: {condition_frames}, target帧数: {target_frames}") + # print(f"原始condition帧数: {original_condition_frames}, target帧数: {original_target_frames}") + # print(f"Latents shape: {latents.shape}") + + # 裁剪空间尺寸以节省内存 + # target_height, target_width = 50, 70 + # current_height, current_width = latents.shape[3], latents.shape[4] + + # if current_height > target_height or current_width > target_width: + # h_start = (current_height - target_height) // 2 + # w_start = (current_width - target_width) // 2 + # latents = latents[:, :, :, + # h_start:h_start+target_height, + # w_start:w_start+target_width] + + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + # print(f"裁剪后latents shape: {latents.shape}") + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + cam_emb = batch["camera"].to(self.device) + + # Loss计算 + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + noisy_condition_latents = copy.deepcopy(latents[:, :, :condition_frames, ...]) + is_add_noise = random.random() + if is_add_noise > 0.2: + # add noise to condition + noise_cond = torch.randn_like(latents[:, :, :condition_frames, ...]) + timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps//4*3, (1,)) + timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + noisy_condition_latents = self.pipe.scheduler.add_noise(latents[:, :, :condition_frames, ...], noise_cond, timestep_cond) + + extra_input = 
self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + # 🔧 关键:使用压缩后的condition长度 + # condition部分保持clean,只对target部分加噪 + noisy_latents[:, :, :condition_frames, ...] = noisy_condition_latents #origin_latents[:, :, :condition_frames, ...] + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + # print(f"targe尺寸: {training_target.shape}") + # 预测噪声 + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + # print(f"pred尺寸: {training_target.shape}") + # 🔧 只对target部分计算loss(使用压缩后的索引) + target_noise_pred = noise_pred[:, :, condition_frames:condition_frames+target_frames, ...] + target_training_target = training_target[:, :, condition_frames:condition_frames+target_frames, ...] 
+ + loss = torch.nn.functional.mse_loss(target_noise_pred.float(), target_training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + print('--------loss------------:',loss) + + # 记录额外信息 + wandb.log({ + "train_loss": loss.item(), + "timestep": timestep.item(), + "condition_frames_compressed": condition_frames, # 压缩后的帧数000 + "target_frames_compressed": target_frames, + "condition_frames_original": original_condition_frames, # 原始帧数 + "target_frames_original": original_target_frames, + "total_frames_compressed": condition_frames + target_frames, + "total_frames_original": original_condition_frames + original_target_frames, + "global_step": self.global_step + }) + + return loss + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + def on_save_checkpoint(self, checkpoint): + checkpoint_dir = "/home/zhuyixuan05/ReCamMaster/RLBench-train" + os.makedirs(checkpoint_dir, exist_ok=True) + + current_step = self.global_step + checkpoint.clear() + + state_dict = self.pipe.denoising_model().state_dict() + torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}_dynamic.ckpt")) + print(f"Saved dynamic model checkpoint: step{current_step}_dynamic.ckpt") + +def train_dynamic(args): + """训练支持动态历史长度的模型""" + dataset = DynamicRLBenchDataset( + args.dataset_path, + steps_per_epoch=args.steps_per_epoch, + min_condition_frames=args.min_condition_frames, + max_condition_frames=args.max_condition_frames, + target_frames=args.target_frames, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + + model = DynamicLightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + 
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + ) + + wandb.init( + project="nuscenes-dynamic-recam", + name=f"dynamic-{args.min_condition_frames}-{args.max_condition_frames}", + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[], + ) + trainer.fit(model, dataloader) + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description="Train Dynamic ReCamMaster") + parser.add_argument("--dataset_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/rlbench") + parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors") + parser.add_argument("--output_path", type=str, default="./") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--steps_per_epoch", type=int, default=2000) + parser.add_argument("--max_epochs", type=int, default=30) + parser.add_argument("--min_condition_frames", type=int, default=10, help="最小条件帧数") + parser.add_argument("--max_condition_frames", type=int, default=40, help="最大条件帧数") + parser.add_argument("--target_frames", type=int, default=32, help="目标帧数") + parser.add_argument("--dataloader_num_workers", type=int, default=4) + parser.add_argument("--accumulate_grad_batches", type=int, default=1) + parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1") + parser.add_argument("--use_gradient_checkpointing", action="store_true") + parser.add_argument("--use_gradient_checkpointing_offload", action="store_true") + parser.add_argument("--resume_ckpt_path", type=str, default=None) + + args = parser.parse_args() + + train_dynamic(args) \ No newline at end of file diff --git a/scripts/train_sekai_dynamic.py 
b/scripts/train_sekai_dynamic.py new file mode 100644 index 0000000000000000000000000000000000000000..92230d2508be3da4660f2ae9d6adf3b44e79cc7e --- /dev/null +++ b/scripts/train_sekai_dynamic.py @@ -0,0 +1,583 @@ +import torch +import torch.nn as nn +import lightning as pl +import wandb +import os +import copy +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import os +import json +import torch +import numpy as np +from PIL import Image +import imageio +import random +from torchvision.transforms import v2 +from einops import rearrange +from pose_classifier import PoseClassifier + +# cam_c2w, [N * 4 * 4] +# stride, frame stride +def get_traj_position_change(cam_c2w, stride=1): + positions = cam_c2w[:, :3, 3] + + traj_coord = [] + tarj_angle = [] + for i in range(0, len(positions) - 2 * stride): + v1 = positions[i + stride] - positions[i] + v2 = positions[i + 2 * stride] - positions[i + stride] + + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + if norm1 < 1e-6 or norm2 < 1e-6: + continue + + cos_angle = np.dot(v1, v2) / (norm1 * norm2) + angle = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0))) + + traj_coord.append(v1) + tarj_angle.append(angle) + + # traj_coord: list of coordinate changes, each element is a [dx, dy, dz] + # tarj_angle: list of position angle changes, each element is an angle in range (0, 180) + return traj_coord, tarj_angle + +def get_traj_rotation_change(cam_c2w, stride=1): + rotations = cam_c2w[:, :3, :3] + + traj_rot_angle = [] + for i in range(0, len(rotations) - stride): + z1 = rotations[i][:, 2] + z2 = rotations[i + stride][:, 2] + + norm1 = np.linalg.norm(z1) + norm2 = np.linalg.norm(z2) + if norm1 < 1e-6 or norm2 < 1e-6: + continue + + cos_angle = np.dot(z1, z2) / (norm1 * norm2) + angle = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0))) + traj_rot_angle.append(angle) + + # traj_rot_angle: list of rotation angle changes, each element is an angle in range (0, 180) + return traj_rot_angle + +def 
compute_relative_pose(pose_a, pose_b, use_torch=False): + """ + 计算相机B相对于相机A的相对位姿矩阵 + + 参数: + pose_a: 相机A的外参矩阵 (4x4),可以是numpy数组或PyTorch张量 + 表示从世界坐标系到相机A坐标系的变换 (world → camera A) + pose_b: 相机B的外参矩阵 (4x4),可以是numpy数组或PyTorch张量 + 表示从世界坐标系到相机B坐标系的变换 (world → camera B) + use_torch: 是否使用PyTorch进行计算,默认使用NumPy + + 返回: + relative_pose: 相对位姿矩阵 (4x4),表示从相机A坐标系到相机B坐标系的变换 + (camera A → camera B) + """ + # 检查输入形状 + assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}" + assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}" + + if use_torch: + # 确保输入是PyTorch张量 + if not isinstance(pose_a, torch.Tensor): + pose_a = torch.from_numpy(pose_a).float() + if not isinstance(pose_b, torch.Tensor): + pose_b = torch.from_numpy(pose_b).float() + + # 计算相对位姿: relative_pose = pose_b × inverse(pose_a) + pose_a_inv = torch.inverse(pose_a) + relative_pose = torch.matmul(pose_b, pose_a_inv) + else: + # 确保输入是NumPy数组 + if not isinstance(pose_a, np.ndarray): + pose_a = np.array(pose_a, dtype=np.float32) + if not isinstance(pose_b, np.ndarray): + pose_b = np.array(pose_b, dtype=np.float32) + + # 计算相对位姿: relative_pose = pose_b × inverse(pose_a) + pose_a_inv = np.linalg.inv(pose_a) + relative_pose = np.matmul(pose_b, pose_a_inv) + + return relative_pose + + +class DynamicSekaiDataset(torch.utils.data.Dataset): + """支持动态历史长度的NuScenes数据集""" + + def __init__(self, base_path, steps_per_epoch, + min_condition_frames=10, max_condition_frames=40, + target_frames=10, height=900, width=1600): + self.base_path = base_path + self.scenes_path = base_path + self.min_condition_frames = min_condition_frames + self.max_condition_frames = max_condition_frames + self.target_frames = target_frames + self.height = height + self.width = width + self.steps_per_epoch = steps_per_epoch + self.pose_classifier = PoseClassifier() + + # 🔧 新增:VAE时间压缩比例 + self.time_compression_ratio = 4 # VAE将时间维度压缩4倍 + + # 查找所有处理好的场景 + self.scene_dirs = [] + if os.path.exists(self.scenes_path): + + for item in 
os.listdir(self.scenes_path): + scene_dir = os.path.join(self.scenes_path, item) + if os.path.isdir(scene_dir): + + encoded_path = os.path.join(scene_dir, "encoded_video.pth") + if os.path.exists(encoded_path): + self.scene_dirs.append(scene_dir) + # print(f"Found {len(self.scene_dirs)} scenes with encoded data") + assert len(self.scene_dirs) > 0, "No encoded scenes found!" + + # 预处理设置 + # self.frame_process = v2.Compose([ + # v2.CenterCrop(size=(height, width)), + # v2.Resize(size=(height, width), antialias=True), + # v2.ToTensor(), + # v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + # ]) + + def calculate_relative_rotation(self, current_rotation, reference_rotation): + """计算相对旋转四元数""" + q_current = torch.tensor(current_rotation, dtype=torch.float32) + q_ref = torch.tensor(reference_rotation, dtype=torch.float32) + + q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]]) + + w1, x1, y1, z1 = q_ref_inv + w2, x2, y2, z2 = q_current + + relative_rotation = torch.tensor([ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + ]) + + return relative_rotation + + def select_dynamic_segment(self, full_latents): + """动态选择条件帧和目标帧 - 修正版本处理VAE时间压缩""" + total_lens = full_latents.shape[1] + # print(f"原始总帧数: {total_frames}, 压缩后: {compressed_total_frames}") + # print(f"原始关键帧: {keyframe_indices[:5]}..., 压缩后: {compressed_keyframe_indices[:5]}...") + + # 随机选择条件帧长度(基于压缩后的帧数) + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + + ratio = random.random() + print('ratio:',ratio) + if ratio<0.15: + condition_frames_compressed = 1 + elif 0.15<=ratio<0.3: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + 
else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if total_lens < min_required_frames: + print(f"压缩后帧数不足: {total_lens} < {min_required_frames}") + return None + + # 随机选择起始位置(基于压缩后的帧数) + max_start = total_lens - min_required_frames - 1 + start_frame_compressed = random.randint(0, max_start) + + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + + + # 使用条件段的最后一个关键帧作为reference + reference_keyframe_compressed = start_frame_compressed + + # 🔧 找到对应的原始关键帧索引用于pose查找 + keyframe_original_idx = [] + for compressed_idx in range(start_frame_compressed,target_end_compressed): + keyframe_original_idx.append(compressed_idx*4) + + + + return { + 'start_frame': start_frame_compressed, # 压缩后的起始帧 + 'condition_frames': condition_frames_compressed, # 压缩后的条件帧数 + 'target_frames': target_frames_compressed, # 压缩后的目标帧数 + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + 'keyframe_original_idx': keyframe_original_idx, # 原始关键帧索引 + + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, # 用于记录 + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + + + def create_pose_embeddings(self, cam_data, segment_info): + """创建pose embeddings - 修正版本,确保与latent帧数对齐""" + cam_data_seq = cam_data['extrinsic'] # 300 * 4 * 4 + keyframe_original_idx = segment_info['keyframe_original_idx'] + # target_keyframe_indices = segment_info['target_keyframe_indices'] + + start_frame = segment_info['start_frame'] * self.time_compression_ratio + end_frame = segment_info['target_range'][1] * self.time_compression_ratio + # frame_range = cam_data_seq[start_frame:end_frame] + + relative_cams = [] + for idx in keyframe_original_idx: + cam_prev = 
cam_data_seq[idx] + cam_next = cam_data_seq[idx+4] + # print('cam_prev:',cam_prev) + # print('idx:',idx) + # assert False + relative_cam = compute_relative_pose(cam_prev,cam_next) + # print('relative_cam:',relative_cam) + # assert False + relative_cams.append(torch.as_tensor(relative_cam[:3,:])) + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + # print(pose_embedding.shape) + # assert False + # print() + # traj_pos_coord_full, tarj_pos_angle_full = get_traj_position_change(cam_data_seq, self.time_compression_ratio) + # traj_rot_angle_full = get_traj_rotation_change(cam_data_seq, self.time_compression_ratio) + + # motion_emb = + + return { + 'camera': pose_embedding + } + + def __getitem__(self, index): + while True: + try: + # 随机选择一个场景 + scene_dir = random.choice(self.scene_dirs) + + # 加载场景信息 + # with open(os.path.join(scene_dir, "scene_info.json"), 'r') as f: + # scene_info = json.load(f) + + # 加载编码的视频数据 + encoded_data = torch.load( + os.path.join(scene_dir, "encoded_video.pth"), + weights_only=False, + map_location="cpu" + ) + + # 🔧 验证latent帧数是否符合预期 + full_latents = encoded_data['latents'] # [C, T, H, W] + cam_data = encoded_data['cam_emb'] + # expected_latent_frames = scene_info['total_frames'] // self.time_compression_ratio + actual_latent_frames = full_latents.shape[1] + + # print(f"场景 {os.path.basename(scene_dir)}: 原始帧数={scene_info['total_frames']}, " + # f"预期latent帧数={expected_latent_frames}, 实际latent帧数={actual_latent_frames}") + + # if abs(actual_latent_frames - expected_latent_frames) > 2: # 允许小的舍入误差 + # print(f"⚠️ Latent帧数不匹配,跳过此样本") + # continue + + # 动态选择段落 + segment_info = self.select_dynamic_segment(full_latents) + if segment_info is None: + continue + # print("segment_info:",segment_info) + # 创建pose embeddings + pose_data = self.create_pose_embeddings(cam_data, segment_info) + if pose_data is None: + continue + + n = 
segment_info["condition_frames"] + m = segment_info['target_frames'] + + + mask = torch.zeros(n+m, dtype=torch.float32) + mask[:n] = 1.0 + mask = mask.view(-1, 1) + + + pose_data["camera"] = torch.cat([pose_data["camera"], mask], dim=1) + # print(pose_data['camera'].shape) + # assert False + # 🔧 使用压缩后的索引提取latent段落 + start_frame = segment_info['start_frame'] # 已经是压缩后的索引 + condition_frames = segment_info['condition_frames'] # 已经是压缩后的帧数 + target_frames = segment_info['target_frames'] # 已经是压缩后的帧数 + + # print(f"提取latent段落: start={start_frame}, condition={condition_frames}, target={target_frames}") + # print(f"Full latents shape: {full_latents.shape}") + + # # 确保索引不越界 + # if start_frame + condition_frames + target_frames > full_latents.shape[1]: + # print(f"索引越界,跳过: {start_frame + condition_frames + target_frames} > {full_latents.shape[1]}") + # continue + + condition_latents = full_latents[:, start_frame:start_frame+condition_frames, :, :] + target_latents = full_latents[:, start_frame+condition_frames:start_frame+condition_frames+target_frames, :, :] + + # print(f"Condition latents shape: {condition_latents.shape}") + # print(f"Target latents shape: {target_latents.shape}") + + # 拼接latents [condition, target] + combined_latents = torch.cat([condition_latents, target_latents], dim=1) + + result = { + "latents": combined_latents, + "prompt_emb": encoded_data["prompt_emb"], + "image_emb": encoded_data.get("image_emb", {}), + "camera": pose_data['camera'], + + "condition_frames": condition_frames, # 压缩后的帧数 + "target_frames": target_frames, # 压缩后的帧数 + "scene_name": os.path.basename(scene_dir), + # 🔧 新增:记录原始帧数用于调试 + "original_condition_frames": segment_info['original_condition_frames'], + "original_target_frames": segment_info['original_target_frames'], + } + + return result + + except Exception as e: + print(f"Error loading sample: {e}") + import traceback + traceback.print_exc() + continue + + def __len__(self): + return self.steps_per_epoch + +class 
DynamicLightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + resume_ckpt_path=None + ): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + # 添加相机编码器 + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + block.cam_encoder = nn.Linear(13 , dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + self.pipe.dit.load_state_dict(state_dict, strict=True) + print('load checkpoint:', resume_ckpt_path) + + self.freeze_parameters() + + # 只训练相机相关和注意力模块 + for name, module in self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn"]): + for param in module.parameters(): + param.requires_grad = True + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + # 创建可视化目录 + self.vis_dir = "sekai_dynamic/visualizations_dynamic" + os.makedirs(self.vis_dir, exist_ok=True) + + def freeze_parameters(self): + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + def training_step(self, batch, batch_idx): + # 
获取动态长度信息(这些已经是压缩后的帧数) + condition_frames = batch["condition_frames"][0].item() # 压缩后的condition长度 + target_frames = batch["target_frames"][0].item() # 压缩后的target长度 + + # 🔧 获取原始帧数用于日志记录 + original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0] + original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0] + + # Data + latents = batch["latents"].to(self.device) + # print(f"压缩后condition帧数: {condition_frames}, target帧数: {target_frames}") + # print(f"原始condition帧数: {original_condition_frames}, target帧数: {original_target_frames}") + # print(f"Latents shape: {latents.shape}") + + # 裁剪空间尺寸以节省内存 + # target_height, target_width = 50, 70 + # current_height, current_width = latents.shape[3], latents.shape[4] + + # if current_height > target_height or current_width > target_width: + # h_start = (current_height - target_height) // 2 + # w_start = (current_width - target_width) // 2 + # latents = latents[:, :, :, + # h_start:h_start+target_height, + # w_start:w_start+target_width] + + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + # print(f"裁剪后latents shape: {latents.shape}") + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + cam_emb = batch["camera"].to(self.device) + + # Loss计算 + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + extra_input = self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + # 🔧 关键:使用压缩后的condition长度 + # condition部分保持clean,只对target部分加噪 + noisy_latents[:, :, 
:condition_frames, ...] = origin_latents[:, :, :condition_frames, ...] + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + # print(f"targe尺寸: {training_target.shape}") + # 预测噪声 + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + # print(f"pred尺寸: {training_target.shape}") + # 🔧 只对target部分计算loss(使用压缩后的索引) + target_noise_pred = noise_pred[:, :, condition_frames:condition_frames+target_frames, ...] + target_training_target = training_target[:, :, condition_frames:condition_frames+target_frames, ...] + + loss = torch.nn.functional.mse_loss(target_noise_pred.float(), target_training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + print('--------loss------------:',loss) + + # 记录额外信息 + wandb.log({ + "train_loss": loss.item(), + "timestep": timestep.item(), + "condition_frames_compressed": condition_frames, # 压缩后的帧数000 + "target_frames_compressed": target_frames, + "condition_frames_original": original_condition_frames, # 原始帧数 + "target_frames_original": original_target_frames, + "total_frames_compressed": condition_frames + target_frames, + "total_frames_original": original_condition_frames + original_target_frames, + "global_step": self.global_step + }) + + return loss + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + def on_save_checkpoint(self, checkpoint): + checkpoint_dir = "/home/zhuyixuan05/ReCamMaster/sekai_dynamic_2" + os.makedirs(checkpoint_dir, exist_ok=True) + + current_step = self.global_step + checkpoint.clear() + + state_dict = self.pipe.denoising_model().state_dict() + 
torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}_dynamic.ckpt")) + print(f"Saved dynamic model checkpoint: step{current_step}_dynamic.ckpt") + +def train_dynamic(args): + """训练支持动态历史长度的模型""" + dataset = DynamicSekaiDataset( + args.dataset_path, + steps_per_epoch=args.steps_per_epoch, + min_condition_frames=args.min_condition_frames, + max_condition_frames=args.max_condition_frames, + target_frames=args.target_frames, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + + model = DynamicLightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + ) + + wandb.init( + project="nuscenes-dynamic-recam", + name=f"dynamic-{args.min_condition_frames}-{args.max_condition_frames}", + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[], + ) + trainer.fit(model, dataloader) + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description="Train Dynamic ReCamMaster") + parser.add_argument("--dataset_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/sekai-game-drone") + parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors") + parser.add_argument("--output_path", type=str, default="./") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--steps_per_epoch", type=int, default=6000) + parser.add_argument("--max_epochs", type=int, default=30) + parser.add_argument("--min_condition_frames", type=int, default=10, 
help="最小条件帧数") + parser.add_argument("--max_condition_frames", type=int, default=40, help="最大条件帧数") + parser.add_argument("--target_frames", type=int, default=32, help="目标帧数") + parser.add_argument("--dataloader_num_workers", type=int, default=4) + parser.add_argument("--accumulate_grad_batches", type=int, default=1) + parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1") + parser.add_argument("--use_gradient_checkpointing", action="store_true") + parser.add_argument("--use_gradient_checkpointing_offload", action="store_true") + parser.add_argument("--resume_ckpt_path", type=str, default=None) + + args = parser.parse_args() + + train_dynamic(args) \ No newline at end of file diff --git a/scripts/train_sekai_framepack.py b/scripts/train_sekai_framepack.py new file mode 100644 index 0000000000000000000000000000000000000000..d6c7019c3f2cba42c02c27b201a7d3f6b3d71341 --- /dev/null +++ b/scripts/train_sekai_framepack.py @@ -0,0 +1,697 @@ +import torch +import torch.nn as nn +import lightning as pl +import wandb +import os +import copy +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import os +import json +import torch +import numpy as np +from PIL import Image +import imageio +import random +from torchvision.transforms import v2 +from einops import rearrange +from pose_classifier import PoseClassifier + +# cam_c2w, [N * 4 * 4] +# stride, frame stride +def get_traj_position_change(cam_c2w, stride=1): + positions = cam_c2w[:, :3, 3] + + traj_coord = [] + tarj_angle = [] + for i in range(0, len(positions) - 2 * stride): + v1 = positions[i + stride] - positions[i] + v2 = positions[i + 2 * stride] - positions[i + stride] + + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + if norm1 < 1e-6 or norm2 < 1e-6: + continue + + cos_angle = np.dot(v1, v2) / (norm1 * norm2) + angle = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0))) + + traj_coord.append(v1) + tarj_angle.append(angle) + + return traj_coord, tarj_angle + +def 
get_traj_rotation_change(cam_c2w, stride=1): + rotations = cam_c2w[:, :3, :3] + + traj_rot_angle = [] + for i in range(0, len(rotations) - stride): + z1 = rotations[i][:, 2] + z2 = rotations[i + stride][:, 2] + + norm1 = np.linalg.norm(z1) + norm2 = np.linalg.norm(z2) + if norm1 < 1e-6 or norm2 < 1e-6: + continue + + cos_angle = np.dot(z1, z2) / (norm1 * norm2) + angle = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0))) + traj_rot_angle.append(angle) + + return traj_rot_angle + +def compute_relative_pose(pose_a, pose_b, use_torch=False): + """ + 计算相机B相对于相机A的相对位姿矩阵 + """ + assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}" + assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}" + + if use_torch: + if not isinstance(pose_a, torch.Tensor): + pose_a = torch.from_numpy(pose_a).float() + if not isinstance(pose_b, torch.Tensor): + pose_b = torch.from_numpy(pose_b).float() + + pose_a_inv = torch.inverse(pose_a) + relative_pose = torch.matmul(pose_b, pose_a_inv) + else: + if not isinstance(pose_a, np.ndarray): + pose_a = np.array(pose_a, dtype=np.float32) + if not isinstance(pose_b, np.ndarray): + pose_b = np.array(pose_b, dtype=np.float32) + + pose_a_inv = np.linalg.inv(pose_a) + relative_pose = np.matmul(pose_b, pose_a_inv) + + return relative_pose + +class DynamicSekaiDataset(torch.utils.data.Dataset): + """支持FramePack机制的动态历史长度数据集 - 支持多个数据集""" + + def __init__(self, base_paths, steps_per_epoch, + min_condition_frames=10, max_condition_frames=40, + target_frames=10, height=900, width=1600): + # 🔧 修改:支持多个数据集路径 + if isinstance(base_paths, str): + base_paths = [base_paths] # 如果是单个路径,转换为列表 + + self.base_paths = base_paths + self.min_condition_frames = min_condition_frames + self.max_condition_frames = max_condition_frames + self.target_frames = target_frames + self.height = height + self.width = width + self.steps_per_epoch = steps_per_epoch + self.pose_classifier = PoseClassifier() + + # VAE时间压缩比例 + self.time_compression_ratio = 4 # 
VAE将时间维度压缩4倍 + + # 🔧 修改:查找所有数据集中的处理好的场景 + self.scene_dirs = [] + self.dataset_info = {} # 记录每个场景属于哪个数据集 + + for base_path in self.base_paths: + dataset_name = os.path.basename(base_path) # 获取数据集名称 + print(f"🔧 扫描数据集: {dataset_name} ({base_path})") + + if os.path.exists(base_path): + dataset_scenes = [] + for item in os.listdir(base_path): + scene_dir = os.path.join(base_path, item) + if os.path.isdir(scene_dir): + encoded_path = os.path.join(scene_dir, "encoded_video.pth") + if os.path.exists(encoded_path): + self.scene_dirs.append(scene_dir) + dataset_scenes.append(scene_dir) + self.dataset_info[scene_dir] = dataset_name + + print(f" ✅ 找到 {len(dataset_scenes)} 个场景") + else: + print(f" ⚠️ 路径不存在: {base_path}") + + print(f"🔧 总共找到 {len(self.scene_dirs)} 个场景") + for dataset_name in set(self.dataset_info.values()): + count = sum(1 for v in self.dataset_info.values() if v == dataset_name) + print(f" - {dataset_name}: {count} 个场景") + + assert len(self.scene_dirs) > 0, "No encoded scenes found!" + + def select_dynamic_segment_framepack(self, full_latents): + """🔧 FramePack风格的动态选择条件帧和目标帧 - 修正版,考虑实际condition长度""" + total_lens = full_latents.shape[1] + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + + ratio = random.random() + if ratio < 0.15: + condition_frames_compressed = 1 + elif 0.15 <= ratio < 0.9: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if total_lens < min_required_frames: + return None + + start_frame_compressed = random.randint(0, total_lens - min_required_frames - 1) + condition_end_compressed = start_frame_compressed + condition_frames_compressed + 
target_end_compressed = condition_end_compressed + target_frames_compressed + + # 🔧 修正:FramePack风格的索引处理 + latent_indices = torch.arange(condition_end_compressed, target_end_compressed) # 只预测未来帧 + + # 🔧 修正:根据实际的condition_frames_compressed生成索引 + # 1x帧:起始帧 + 最后1帧 + clean_latent_indices_start = torch.tensor([start_frame_compressed]) + clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1]) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices]) + + # 🔧 2x帧:根据实际condition长度确定 + if condition_frames_compressed >= 2: + # 取最后2帧(如果有的话) + clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2-1) + clean_latent_2x_indices = torch.arange(clean_latent_2x_start, condition_end_compressed-1) + else: + # 如果condition帧数不足2帧,创建空索引 + clean_latent_2x_indices = torch.tensor([], dtype=torch.long) + + # 🔧 4x帧:根据实际condition长度确定,最多16帧 + if condition_frames_compressed > 3: + # 取最多16帧的历史(如果有的话) + clean_4x_start = max(start_frame_compressed, condition_end_compressed - 16-3) + clean_latent_4x_indices = torch.arange(clean_4x_start, condition_end_compressed-3) + else: + clean_latent_4x_indices = torch.tensor([], dtype=torch.long) + + # 对应的原始关键帧索引 + keyframe_original_idx = [] + for compressed_idx in range(start_frame_compressed, target_end_compressed): + keyframe_original_idx.append(compressed_idx * 4) + + return { + 'start_frame': start_frame_compressed, + 'condition_frames': condition_frames_compressed, + 'target_frames': target_frames_compressed, + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + + # FramePack风格的索引 + 'latent_indices': latent_indices, + 'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + + 'keyframe_original_idx': keyframe_original_idx, + 'original_condition_frames': condition_frames_compressed * 
self.time_compression_ratio, + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + + def create_pose_embeddings(self, cam_data, segment_info): + """🔧 创建pose embeddings - 为所有帧(condition + target)提取camera信息,支持0填充""" + cam_data_seq = cam_data['extrinsic'] # 300 * 4 * 4 + + # 🔧 修正:为所有帧(condition + target)计算camera embedding + start_frame = segment_info['start_frame'] * self.time_compression_ratio + end_frame = segment_info['target_range'][1] * self.time_compression_ratio + + # 为所有帧计算相对pose + all_keyframe_indices = [] + for compressed_idx in range(segment_info['start_frame'], segment_info['target_range'][1]): + all_keyframe_indices.append(compressed_idx * 4) + + relative_cams = [] + for idx in all_keyframe_indices: + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx + 4] + relative_cam = compute_relative_pose(cam_prev, cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def prepare_framepack_inputs(self, full_latents, segment_info): + """🔧 准备FramePack风格的多尺度输入 - 修正版,正确处理空索引""" + # 🔧 修正:处理4维输入 [C, T, H, W],添加batch维度 + if len(full_latents.shape) == 4: + full_latents = full_latents.unsqueeze(0) # [C, T, H, W] -> [1, C, T, H, W] + B, C, T, H, W = full_latents.shape + else: + B, C, T, H, W = full_latents.shape + + # 主要latents(用于去噪预测) + latent_indices = segment_info['latent_indices'] + main_latents = full_latents[:, :, latent_indices, :, :] # 注意维度顺序 + + # 🔧 1x条件帧(起始帧 + 最后1帧) + clean_latent_indices = segment_info['clean_latent_indices'] + clean_latents = full_latents[:, :, clean_latent_indices, :, :] # 注意维度顺序 + + # 🔧 4x条件帧 - 总是16帧,直接用真实索引 + 0填充 + clean_latent_4x_indices = segment_info['clean_latent_4x_indices'] + + # 创建固定长度16的latents,初始化为0 + clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype) + 
clean_latent_4x_indices_final = torch.full((16,), -1, dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的4x索引 + if len(clean_latent_4x_indices) > 0: + actual_4x_frames = len(clean_latent_4x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 16 - actual_4x_frames) + end_pos = 16 + actual_start = max(0, actual_4x_frames - 16) # 如果超过16帧,只取最后16帧 + + clean_latents_4x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :] + clean_latent_4x_indices_final[start_pos:end_pos] = clean_latent_4x_indices[actual_start:] + + # 🔧 2x条件帧 - 总是2帧,直接用真实索引 + 0填充 + clean_latent_2x_indices = segment_info['clean_latent_2x_indices'] + + # 创建固定长度2的latents,初始化为0 + clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype) + clean_latent_2x_indices_final = torch.full((2,), -1, dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的2x索引 + if len(clean_latent_2x_indices) > 0: + actual_2x_frames = len(clean_latent_2x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 2 - actual_2x_frames) + end_pos = 2 + actual_start = max(0, actual_2x_frames - 2) # 如果超过2帧,只取最后2帧 + + clean_latents_2x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_2x_indices[actual_start:], :, :] + clean_latent_2x_indices_final[start_pos:end_pos] = clean_latent_2x_indices[actual_start:] + + # 🔧 移除添加的batch维度,返回原始格式 + if B == 1: + main_latents = main_latents.squeeze(0) # [1, C, T, H, W] -> [C, T, H, W] + clean_latents = clean_latents.squeeze(0) + clean_latents_2x = clean_latents_2x.squeeze(0) + clean_latents_4x = clean_latents_4x.squeeze(0) + + return { + 'latents': main_latents, + 'clean_latents': clean_latents, + 'clean_latents_2x': clean_latents_2x, + 'clean_latents_4x': clean_latents_4x, + 'latent_indices': segment_info['latent_indices'], + 'clean_latent_indices': segment_info['clean_latent_indices'], + 'clean_latent_2x_indices': clean_latent_2x_indices_final, # 🔧 使用真实索引(含-1填充) + 'clean_latent_4x_indices': clean_latent_4x_indices_final, # 🔧 使用真实索引(含-1填充) + } + 
+ def __getitem__(self, index): + while True: + try: + # 🔧 修改:随机选择一个场景(从所有数据集中) + scene_dir = random.choice(self.scene_dirs) + dataset_name = self.dataset_info[scene_dir] # 获取该场景所属的数据集 + + # 加载编码的视频数据 + encoded_data = torch.load( + os.path.join(scene_dir, "encoded_video.pth"), + weights_only=False, + map_location="cpu" + ) + + full_latents = encoded_data['latents'] # [C, T, H, W] + cam_data = encoded_data['cam_emb'] + + # 🔧 使用FramePack风格的段落选择 + segment_info = self.select_dynamic_segment_framepack(full_latents) + if segment_info is None: + continue + + # 🔧 修正:为所有帧创建pose embeddings(不带mask) + all_camera_embeddings = self.create_pose_embeddings(cam_data, segment_info) + if all_camera_embeddings is None: + continue + + # 🔧 准备FramePack风格的多尺度输入(在这里处理0填充) + framepack_inputs = self.prepare_framepack_inputs(full_latents, segment_info) + + n = segment_info["condition_frames"] + m = segment_info['target_frames'] + + # 🔧 简化:像train_sekai_walking一样处理camera embedding with mask + mask = torch.zeros(n+m, dtype=torch.float32) + mask[:n] = 1.0 # condition帧标记为1 + mask = mask.view(-1, 1) + + # 添加mask到camera embeddings + camera_with_mask = torch.cat([all_camera_embeddings, mask], dim=1) + + result = { + # 🔧 FramePack风格的多尺度输入 - 现在都有固定长度 + "latents": framepack_inputs['latents'], # 主要预测目标 + "clean_latents": framepack_inputs['clean_latents'], # 条件帧(2帧) + "clean_latents_2x": framepack_inputs['clean_latents_2x'], # 2x条件帧(2帧,不足用0填充) + "clean_latents_4x": framepack_inputs['clean_latents_4x'], # 4x条件帧(16帧,不足用0填充) + "latent_indices": framepack_inputs['latent_indices'], + "clean_latent_indices": framepack_inputs['clean_latent_indices'], + "clean_latent_2x_indices": framepack_inputs['clean_latent_2x_indices'], # 固定长度 + "clean_latent_4x_indices": framepack_inputs['clean_latent_4x_indices'], # 固定长度 + + # 🔧 简化:直接传递带mask的camera embeddings + "camera": camera_with_mask, # 所有帧的camera embeddings(带mask) + + "prompt_emb": encoded_data["prompt_emb"], + "image_emb": encoded_data.get("image_emb", {}), + + 
"condition_frames": n, + "target_frames": m, + "scene_name": os.path.basename(scene_dir), + "dataset_name": dataset_name, # 🔧 新增:记录数据集名称 + "original_condition_frames": segment_info['original_condition_frames'], + "original_target_frames": segment_info['original_target_frames'], + } + + return result + + except Exception as e: + print(f"Error loading sample: {e}") + import traceback + traceback.print_exc() + continue + + def __len__(self): + return self.steps_per_epoch + +def replace_dit_model_in_manager(): + """在模型加载前替换DiT模型类""" + from diffsynth.models.wan_video_dit_recam_future import WanModelFuture + from diffsynth.configs.model_config import model_loader_configs + + # 修改model_loader_configs中的配置 + for i, config in enumerate(model_loader_configs): + keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config + + # 检查是否包含wan_video_dit模型 + if 'wan_video_dit' in model_names: + # 找到wan_video_dit的索引并替换为WanModelFuture + new_model_names = [] + new_model_classes = [] + + for name, cls in zip(model_names, model_classes): + if name == 'wan_video_dit': + new_model_names.append(name) # 保持名称不变 + new_model_classes.append(WanModelFuture) # 替换为新的类 + print(f"✅ 替换了模型类: {name} -> WanModelFuture") + else: + new_model_names.append(name) + new_model_classes.append(cls) + + # 更新配置 + model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) + +class DynamicLightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + resume_ckpt_path=None + ): + super().__init__() + replace_dit_model_in_manager() # 在这里调用 + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + 
model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + # 🔧 添加FramePack的clean_x_embedder - 参考hunyuan_video_packed.py + self.add_framepack_components() + + # 添加相机编码器 + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + block.cam_encoder = nn.Linear(13 , dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + self.pipe.dit.load_state_dict(state_dict, strict=True) + print('load checkpoint:', resume_ckpt_path) + + self.freeze_parameters() + + # 只训练相机相关和注意力模块以及FramePack相关组件 + for name, module in self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn", "clean_x_embedder"]): + for param in module.parameters(): + param.requires_grad = True + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + # 创建可视化目录 + self.vis_dir = "sekai_dynamic/visualizations_dynamic" + os.makedirs(self.vis_dir, exist_ok=True) + + def add_framepack_components(self): + """🔧 添加FramePack相关组件 - 参考hunyuan_video_packed.py""" + if not hasattr(self.pipe.dit, 'clean_x_embedder'): + inner_dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + # 参考hunyuan_video_packed.py的设计 + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + 
self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + if scale == "1x": + return self.proj(x) + elif scale == "2x": + return self.proj_2x(x) + elif scale == "4x": + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + self.pipe.dit.clean_x_embedder = CleanXEmbedder(inner_dim) + print("✅ 添加了FramePack的clean_x_embedder组件") + + def freeze_parameters(self): + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + def training_step(self, batch, batch_idx): + """🔧 使用FramePack风格的训练步骤 - 修正维度处理""" + condition_frames = batch["condition_frames"][0].item() + target_frames = batch["target_frames"][0].item() + + original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0] + original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0] + + dataset_name = batch.get("dataset_name", ["unknown"])[0] + scene_name = batch.get("scene_name", ["unknown"])[0] + # 🔧 准备FramePack风格的输入 - 确保有batch维度 + latents = batch["latents"].to(self.device) + if len(latents.shape) == 4: # [C, T, H, W] + latents = latents.unsqueeze(0) # -> [1, C, T, H, W] + + # 🔧 条件输入(处理空张量和维度) + clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None + if clean_latents is not None and len(clean_latents.shape) == 4: + clean_latents = clean_latents.unsqueeze(0) + + clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None + if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4: + clean_latents_2x = clean_latents_2x.unsqueeze(0) + + clean_latents_4x = batch["clean_latents_4x"].to(self.device) if batch["clean_latents_4x"].numel() > 0 else None + if clean_latents_4x is not None and len(clean_latents_4x.shape) == 4: + clean_latents_4x = clean_latents_4x.unsqueeze(0) + + # 🔧 索引(处理空张量) + latent_indices = 
batch["latent_indices"].to(self.device) + clean_latent_indices = batch["clean_latent_indices"].to(self.device) if batch["clean_latent_indices"].numel() > 0 else None + clean_latent_2x_indices = batch["clean_latent_2x_indices"].to(self.device) if batch["clean_latent_2x_indices"].numel() > 0 else None + clean_latent_4x_indices = batch["clean_latent_4x_indices"].to(self.device) if batch["clean_latent_4x_indices"].numel() > 0 else None + + # 🔧 简化:直接使用带mask的camera embeddings + cam_emb = batch["camera"].to(self.device) + camera_dropout_prob = 0.1 # 10%概率丢弃camera条件 + if random.random() < camera_dropout_prob: + # 创建零camera embedding + cam_emb = torch.zeros_like(cam_emb) + print("应用camera dropout for CFG training") + + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + # Loss计算 + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + # 🔧 FramePack风格的噪声处理 + noisy_condition_latents = None + if clean_latents is not None: + noisy_condition_latents = copy.deepcopy(clean_latents) + is_add_noise = random.random() + if is_add_noise > 0.2: # 80%概率添加噪声 + noise_cond = torch.randn_like(clean_latents) + timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps//4*3, (1,)) + timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + noisy_condition_latents = self.pipe.scheduler.add_noise(clean_latents, noise_cond, timestep_cond) + + extra_input = self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + 
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # 🔧 使用FramePack风格的forward调用 + noise_pred = self.pipe.denoising_model()( + noisy_latents, + timestep=timestep, + cam_emb=cam_emb, # 🔧 简化:直接传递带mask的camera embeddings + # 🔧 FramePack风格的条件输入 + latent_indices=latent_indices, + clean_latents=noisy_condition_latents if noisy_condition_latents is not None else clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb, + **extra_input, + **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + + # 计算loss(现在noise_pred只包含预测目标,不包含条件部分) + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + print('--------loss------------:', loss) + + return loss + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + def on_save_checkpoint(self, checkpoint): + checkpoint_dir = "/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack" + os.makedirs(checkpoint_dir, exist_ok=True) + + current_step = self.global_step + checkpoint.clear() + + state_dict = self.pipe.denoising_model().state_dict() + torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}_framepack.ckpt")) + print(f"Saved FramePack model checkpoint: step{current_step}_framepack.ckpt") + +def train_dynamic(args): + """训练支持FramePack机制的动态历史长度模型 - 支持多数据集""" + # 🔧 修改:支持多个数据集路径 + dataset_paths = [ + "/share_zhuyixuan05/zhuyixuan05/sekai-game-drone", + 
"/share_zhuyixuan05/zhuyixuan05/sekai-game-walking" + ] + + dataset = DynamicSekaiDataset( + dataset_paths, # 🔧 传入多个数据集路径 + steps_per_epoch=args.steps_per_epoch, + min_condition_frames=args.min_condition_frames, + max_condition_frames=args.max_condition_frames, + target_frames=args.target_frames, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + + model = DynamicLightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + ) + + # wandb.init( + # project="sekai-multi-dataset-framepack-recam", # 🔧 修改项目名称 + # name=f"multi-dataset-framepack-{args.min_condition_frames}-{args.max_condition_frames}", + # ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + logger=False, + callbacks=[], + ) + trainer.fit(model, dataloader) + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description="Train FramePack Dynamic ReCamMaster with Multiple Datasets") + # 🔧 修改:dataset_path参数现在在代码中硬编码,但保留以便兼容 + parser.add_argument("--dataset_path", type=str, + default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking", + help="主数据集路径(实际会使用代码中的多数据集配置)") + parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors") + parser.add_argument("--output_path", type=str, default="./") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--steps_per_epoch", type=int, default=8000) + parser.add_argument("--max_epochs", type=int, default=3000) + parser.add_argument("--min_condition_frames", 
# cam_c2w: [N, 4, 4] camera-to-world matrices; stride: frame stride.
def get_traj_position_change(cam_c2w, stride=1):
    """Per-step position deltas and turning angles along a camera trajectory.

    Returns (traj_coord, traj_angle): displacement vectors between frames
    i and i+stride, and the angle in degrees (0-180) between consecutive
    displacements. Near-stationary steps are skipped (direction undefined).
    """
    positions = cam_c2w[:, :3, 3]

    traj_coord = []
    traj_angle = []
    for i in range(len(positions) - 2 * stride):
        v1 = positions[i + stride] - positions[i]
        v2 = positions[i + 2 * stride] - positions[i + stride]

        norm1, norm2 = np.linalg.norm(v1), np.linalg.norm(v2)
        if norm1 < 1e-6 or norm2 < 1e-6:
            continue

        cos_angle = np.dot(v1, v2) / (norm1 * norm2)
        traj_coord.append(v1)
        traj_angle.append(np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0))))

    return traj_coord, traj_angle

def get_traj_rotation_change(cam_c2w, stride=1):
    """Angles in degrees between the camera z-axes of frames i and i+stride."""
    rotations = cam_c2w[:, :3, :3]

    traj_rot_angle = []
    for i in range(len(rotations) - stride):
        z1 = rotations[i][:, 2]
        z2 = rotations[i + stride][:, 2]

        norm1, norm2 = np.linalg.norm(z1), np.linalg.norm(z2)
        if norm1 < 1e-6 or norm2 < 1e-6:
            continue

        cos_angle = np.dot(z1, z2) / (norm1 * norm2)
        traj_rot_angle.append(np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0))))

    return traj_rot_angle

def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Pose of camera B relative to camera A: ``pose_b @ inverse(pose_a)``.

    Both inputs are 4x4 extrinsic matrices (numpy arrays or torch tensors);
    ``use_torch`` selects the torch code path, otherwise numpy is used.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
self.min_condition_frames = min_condition_frames + self.max_condition_frames = max_condition_frames + self.target_frames = target_frames + self.height = height + self.width = width + self.steps_per_epoch = steps_per_epoch + self.pose_classifier = PoseClassifier() + + # VAE时间压缩比例 + self.time_compression_ratio = 4 # VAE将时间维度压缩4倍 + + # 🔧 修改:查找所有数据集中的处理好的场景 + self.scene_dirs = [] + self.dataset_info = {} # 记录每个场景属于哪个数据集 + + for base_path in self.base_paths: + dataset_name = os.path.basename(base_path) # 获取数据集名称 + print(f"🔧 扫描数据集: {dataset_name} ({base_path})") + + if os.path.exists(base_path): + dataset_scenes = [] + for item in os.listdir(base_path): + scene_dir = os.path.join(base_path, item) + if os.path.isdir(scene_dir): + encoded_path = os.path.join(scene_dir, "encoded_video.pth") + if os.path.exists(encoded_path): + self.scene_dirs.append(scene_dir) + dataset_scenes.append(scene_dir) + self.dataset_info[scene_dir] = dataset_name + + print(f" ✅ 找到 {len(dataset_scenes)} 个场景") + else: + print(f" ⚠️ 路径不存在: {base_path}") + + print(f"🔧 总共找到 {len(self.scene_dirs)} 个场景") + for dataset_name in set(self.dataset_info.values()): + count = sum(1 for v in self.dataset_info.values() if v == dataset_name) + print(f" - {dataset_name}: {count} 个场景") + + assert len(self.scene_dirs) > 0, "No encoded scenes found!" 
+ + def select_dynamic_segment_framepack(self, full_latents): + """🔧 FramePack风格的动态选择条件帧和目标帧 - 修正版,考虑实际condition长度""" + total_lens = full_latents.shape[1] + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + + ratio = random.random() + if ratio < 0.15: + condition_frames_compressed = 1 + elif 0.15 <= ratio < 0.9: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if total_lens < min_required_frames: + return None + + start_frame_compressed = random.randint(0, total_lens - min_required_frames - 1) + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + # 🔧 修正:FramePack风格的索引处理 + latent_indices = torch.arange(condition_end_compressed, target_end_compressed) # 只预测未来帧 + + # 🔧 修改:1x帧改为起始4帧 + 最后1帧 + # 起始4帧(不足的用0填充) + clean_latent_start_indices = torch.arange(start_frame_compressed, min(start_frame_compressed + 4, condition_end_compressed)) + clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1]) + clean_latent_indices = torch.cat([clean_latent_start_indices, clean_latent_1x_indices]) + + # 🔧 2x帧:根据实际condition长度确定 + if condition_frames_compressed >= 2: + # 取最后2帧(如果有的话) + clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2) + clean_latent_2x_indices = torch.arange(clean_latent_2x_start, condition_end_compressed) + else: + # 如果condition帧数不足2帧,创建空索引 + clean_latent_2x_indices = torch.tensor([], dtype=torch.long) + + # 🔧 4x帧:根据实际condition长度确定,最多16帧 + if condition_frames_compressed >= 1: + # 取最多16帧的历史(如果有的话) + clean_4x_start 
= max(start_frame_compressed, condition_end_compressed - 16) + clean_latent_4x_indices = torch.arange(clean_4x_start, condition_end_compressed) + else: + clean_latent_4x_indices = torch.tensor([], dtype=torch.long) + + # 对应的原始关键帧索引 + keyframe_original_idx = [] + for compressed_idx in range(start_frame_compressed, target_end_compressed): + keyframe_original_idx.append(compressed_idx * 4) + + return { + 'start_frame': start_frame_compressed, + 'condition_frames': condition_frames_compressed, + 'target_frames': target_frames_compressed, + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + + # FramePack风格的索引 + 'latent_indices': latent_indices, + 'clean_latent_indices': clean_latent_indices, + 'clean_latent_2x_indices': clean_latent_2x_indices, + 'clean_latent_4x_indices': clean_latent_4x_indices, + + 'keyframe_original_idx': keyframe_original_idx, + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + + def prepare_framepack_inputs(self, full_latents, segment_info): + """🔧 准备FramePack风格的多尺度输入 - 修正版,支持起始4帧+最后1帧""" + # 🔧 修正:处理4维输入 [C, T, H, W],添加batch维度 + if len(full_latents.shape) == 4: + full_latents = full_latents.unsqueeze(0) # [C, T, H, W] -> [1, C, T, H, W] + B, C, T, H, W = full_latents.shape + else: + B, C, T, H, W = full_latents.shape + + # 主要latents(用于去噪预测) + latent_indices = segment_info['latent_indices'] + main_latents = full_latents[:, :, latent_indices, :, :] # 注意维度顺序 + + # 🔧 1x条件帧:起始4帧 + 最后1帧,固定长度为5帧 + clean_latent_indices = segment_info['clean_latent_indices'] + + # 创建固定长度5的latents,初始化为0 + clean_latents = torch.zeros(B, C, 5, H, W, dtype=full_latents.dtype) + clean_latent_indices_final = torch.full((5,), -1, dtype=torch.long) # -1表示padding + + # 填充真实的clean latents + if len(clean_latent_indices) > 0: + # 获取真实的latent数据 + actual_clean_latents = 
full_latents[:, :, clean_latent_indices, :, :] + actual_frames = actual_clean_latents.shape[2] + + if actual_frames <= 5: + # 如果实际帧数不超过5,从前往后填充 + clean_latents[:, :, :actual_frames, :, :] = actual_clean_latents + clean_latent_indices_final[:actual_frames] = clean_latent_indices + else: + # 如果超过5帧,取前4帧+最后1帧 + clean_latents[:, :, :4, :, :] = actual_clean_latents[:, :, :4, :, :] + clean_latents[:, :, 4:5, :, :] = actual_clean_latents[:, :, -1:, :, :] + clean_latent_indices_final[:4] = clean_latent_indices[:4] + clean_latent_indices_final[4:5] = clean_latent_indices[-1:] + + # 🔧 4x条件帧 - 总是16帧,直接用真实索引 + 0填充 + clean_latent_4x_indices = segment_info['clean_latent_4x_indices'] + + # 创建固定长度16的latents,初始化为0 + clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype) + clean_latent_4x_indices_final = torch.full((16,), -1, dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的4x索引 + if len(clean_latent_4x_indices) > 0: + actual_4x_frames = len(clean_latent_4x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 16 - actual_4x_frames) + end_pos = 16 + actual_start = max(0, actual_4x_frames - 16) # 如果超过16帧,只取最后16帧 + + clean_latents_4x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :] + clean_latent_4x_indices_final[start_pos:end_pos] = clean_latent_4x_indices[actual_start:] + + # 🔧 2x条件帧 - 总是2帧,直接用真实索引 + 0填充 + clean_latent_2x_indices = segment_info['clean_latent_2x_indices'] + + # 创建固定长度2的latents,初始化为0 + clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype) + clean_latent_2x_indices_final = torch.full((2,), -1, dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的2x索引 + if len(clean_latent_2x_indices) > 0: + actual_2x_frames = len(clean_latent_2x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 2 - actual_2x_frames) + end_pos = 2 + actual_start = max(0, actual_2x_frames - 2) # 如果超过2帧,只取最后2帧 + + clean_latents_2x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_2x_indices[actual_start:], :, 
:] + clean_latent_2x_indices_final[start_pos:end_pos] = clean_latent_2x_indices[actual_start:] + + # 🔧 移除添加的batch维度,返回原始格式 + if B == 1: + main_latents = main_latents.squeeze(0) # [1, C, T, H, W] -> [C, T, H, W] + clean_latents = clean_latents.squeeze(0) + clean_latents_2x = clean_latents_2x.squeeze(0) + clean_latents_4x = clean_latents_4x.squeeze(0) + + return { + 'latents': main_latents, + 'clean_latents': clean_latents, # 现在是5帧:起始4帧+最后1帧 + 'clean_latents_2x': clean_latents_2x, + 'clean_latents_4x': clean_latents_4x, + 'latent_indices': segment_info['latent_indices'], + 'clean_latent_indices': clean_latent_indices_final, # 🔧 使用真实索引(含-1填充) + 'clean_latent_2x_indices': clean_latent_2x_indices_final, # 🔧 使用真实索引(含-1填充) + 'clean_latent_4x_indices': clean_latent_4x_indices_final, # 🔧 使用真实索引(含-1填充) + } + + def create_pose_embeddings(self, cam_data, segment_info): + """🔧 创建pose embeddings - 为所有帧(condition + target)提取camera信息,支持起始4帧+最后1帧""" + cam_data_seq = cam_data['extrinsic'] # 300 * 4 * 4 + + # 🔧 修正:为所有帧(condition + target)计算camera embedding + start_frame = segment_info['start_frame'] * self.time_compression_ratio + end_frame = segment_info['target_range'][1] * self.time_compression_ratio + + # 为所有帧计算相对pose + all_keyframe_indices = [] + for compressed_idx in range(segment_info['start_frame'], segment_info['target_range'][1]): + all_keyframe_indices.append(compressed_idx * 4) + + relative_cams = [] + for idx in all_keyframe_indices: + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx + 4] + relative_cam = compute_relative_pose(cam_prev, cam_next) + relative_cams.append(torch.as_tensor(relative_cam[:3, :])) + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def __getitem__(self, index): + while True: + try: + # 🔧 修改:随机选择一个场景(从所有数据集中) + scene_dir = random.choice(self.scene_dirs) + dataset_name = 
self.dataset_info[scene_dir] # 获取该场景所属的数据集 + + # 加载编码的视频数据 + encoded_data = torch.load( + os.path.join(scene_dir, "encoded_video.pth"), + weights_only=False, + map_location="cpu" + ) + + full_latents = encoded_data['latents'] # [C, T, H, W] + cam_data = encoded_data['cam_emb'] + + # 🔧 使用FramePack风格的段落选择 + segment_info = self.select_dynamic_segment_framepack(full_latents) + if segment_info is None: + continue + + # 🔧 修正:为所有帧创建pose embeddings(不带mask) + all_camera_embeddings = self.create_pose_embeddings(cam_data, segment_info) + if all_camera_embeddings is None: + continue + + # 🔧 准备FramePack风格的多尺度输入(在这里处理0填充) + framepack_inputs = self.prepare_framepack_inputs(full_latents, segment_info) + + n = segment_info["condition_frames"] + m = segment_info['target_frames'] + + # 🔧 简化:像train_sekai_walking一样处理camera embedding with mask + mask = torch.zeros(n+m, dtype=torch.float32) + mask[:n] = 1.0 # condition帧标记为1 + mask = mask.view(-1, 1) + + # 添加mask到camera embeddings + camera_with_mask = torch.cat([all_camera_embeddings, mask], dim=1) + + result = { + # 🔧 FramePack风格的多尺度输入 - 现在都有固定长度 + "latents": framepack_inputs['latents'], # 主要预测目标 + "clean_latents": framepack_inputs['clean_latents'], # 条件帧(2帧) + "clean_latents_2x": framepack_inputs['clean_latents_2x'], # 2x条件帧(2帧,不足用0填充) + "clean_latents_4x": framepack_inputs['clean_latents_4x'], # 4x条件帧(16帧,不足用0填充) + "latent_indices": framepack_inputs['latent_indices'], + "clean_latent_indices": framepack_inputs['clean_latent_indices'], + "clean_latent_2x_indices": framepack_inputs['clean_latent_2x_indices'], # 固定长度 + "clean_latent_4x_indices": framepack_inputs['clean_latent_4x_indices'], # 固定长度 + + # 🔧 简化:直接传递带mask的camera embeddings + "camera": camera_with_mask, # 所有帧的camera embeddings(带mask) + + "prompt_emb": encoded_data["prompt_emb"], + "image_emb": encoded_data.get("image_emb", {}), + + "condition_frames": n, + "target_frames": m, + "scene_name": os.path.basename(scene_dir), + "dataset_name": dataset_name, # 🔧 新增:记录数据集名称 + 
"original_condition_frames": segment_info['original_condition_frames'], + "original_target_frames": segment_info['original_target_frames'], + } + + return result + + except Exception as e: + print(f"Error loading sample: {e}") + import traceback + traceback.print_exc() + continue + + def __len__(self): + return self.steps_per_epoch + +def replace_dit_model_in_manager(): + """在模型加载前替换DiT模型类""" + from diffsynth.models.wan_video_dit_4 import WanModelFuture4 + from diffsynth.configs.model_config import model_loader_configs + + # 修改model_loader_configs中的配置 + for i, config in enumerate(model_loader_configs): + keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config + + # 检查是否包含wan_video_dit模型 + if 'wan_video_dit' in model_names: + # 找到wan_video_dit的索引并替换为WanModelFuture + new_model_names = [] + new_model_classes = [] + + for name, cls in zip(model_names, model_classes): + if name == 'wan_video_dit': + new_model_names.append(name) # 保持名称不变 + new_model_classes.append(WanModelFuture4) # 替换为新的类 + print(f"✅ 替换了模型类: {name} -> WanModelFuture4") + else: + new_model_names.append(name) + new_model_classes.append(cls) + + # 更新配置 + model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) + +class DynamicLightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + resume_ckpt_path=None + ): + super().__init__() + replace_dit_model_in_manager() # 在这里调用 + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + 
+ # 🔧 添加FramePack的clean_x_embedder - 参考hunyuan_video_packed.py + self.add_framepack_components() + + # 添加相机编码器 + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + block.cam_encoder = nn.Linear(13 , dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + self.pipe.dit.load_state_dict(state_dict, strict=True) + print('load checkpoint:', resume_ckpt_path) + + self.freeze_parameters() + + # 只训练相机相关和注意力模块以及FramePack相关组件 + for name, module in self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn", "clean_x_embedder"]): + for param in module.parameters(): + param.requires_grad = True + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + # 创建可视化目录 + self.vis_dir = "sekai_dynamic/visualizations_dynamic" + os.makedirs(self.vis_dir, exist_ok=True) + + def add_framepack_components(self): + """🔧 添加FramePack相关组件 - 参考hunyuan_video_packed.py""" + if not hasattr(self.pipe.dit, 'clean_x_embedder'): + inner_dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + + class CleanXEmbedder(nn.Module): + def __init__(self, inner_dim): + super().__init__() + # 参考hunyuan_video_packed.py的设计 + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + def forward(self, x, scale="1x"): + if scale == "1x": + return self.proj(x) + elif scale == "2x": + return self.proj_2x(x) + elif 
scale == "4x": + return self.proj_4x(x) + else: + raise ValueError(f"Unsupported scale: {scale}") + + self.pipe.dit.clean_x_embedder = CleanXEmbedder(inner_dim) + print("✅ 添加了FramePack的clean_x_embedder组件") + + def freeze_parameters(self): + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + def training_step(self, batch, batch_idx): + """🔧 使用FramePack风格的训练步骤 - 修正维度处理""" + condition_frames = batch["condition_frames"][0].item() + target_frames = batch["target_frames"][0].item() + + original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0] + original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0] + + dataset_name = batch.get("dataset_name", ["unknown"])[0] + scene_name = batch.get("scene_name", ["unknown"])[0] + # 🔧 准备FramePack风格的输入 - 确保有batch维度 + latents = batch["latents"].to(self.device) + if len(latents.shape) == 4: # [C, T, H, W] + latents = latents.unsqueeze(0) # -> [1, C, T, H, W] + + # 🔧 条件输入(处理空张量和维度) + clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None + if clean_latents is not None and len(clean_latents.shape) == 4: + clean_latents = clean_latents.unsqueeze(0) + + clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None + if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4: + clean_latents_2x = clean_latents_2x.unsqueeze(0) + + clean_latents_4x = batch["clean_latents_4x"].to(self.device) if batch["clean_latents_4x"].numel() > 0 else None + if clean_latents_4x is not None and len(clean_latents_4x.shape) == 4: + clean_latents_4x = clean_latents_4x.unsqueeze(0) + + # 🔧 索引(处理空张量) + latent_indices = batch["latent_indices"].to(self.device) + clean_latent_indices = batch["clean_latent_indices"].to(self.device) if batch["clean_latent_indices"].numel() > 0 else None + clean_latent_2x_indices = 
batch["clean_latent_2x_indices"].to(self.device) if batch["clean_latent_2x_indices"].numel() > 0 else None + clean_latent_4x_indices = batch["clean_latent_4x_indices"].to(self.device) if batch["clean_latent_4x_indices"].numel() > 0 else None + + # 🔧 简化:直接使用带mask的camera embeddings + cam_emb = batch["camera"].to(self.device) + camera_dropout_prob = 0.1 # 10%概率丢弃camera条件 + if random.random() < camera_dropout_prob: + # 创建零camera embedding + cam_emb = torch.zeros_like(cam_emb) + print("应用camera dropout for CFG training") + + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + # Loss计算 + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + # 🔧 FramePack风格的噪声处理 + noisy_condition_latents = None + if clean_latents is not None: + noisy_condition_latents = copy.deepcopy(clean_latents) + is_add_noise = random.random() + if is_add_noise > 0.2: # 80%概率添加噪声 + noise_cond = torch.randn_like(clean_latents) + timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps//4*3, (1,)) + timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + noisy_condition_latents = self.pipe.scheduler.add_noise(clean_latents, noise_cond, timestep_cond) + + extra_input = self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # 🔧 使用FramePack风格的forward调用 + noise_pred 
= self.pipe.denoising_model()( + noisy_latents, + timestep=timestep, + cam_emb=cam_emb, # 🔧 简化:直接传递带mask的camera embeddings + # 🔧 FramePack风格的条件输入 + latent_indices=latent_indices, + clean_latents=noisy_condition_latents if noisy_condition_latents is not None else clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + **prompt_emb, + **extra_input, + **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + + # 计算loss(现在noise_pred只包含预测目标,不包含条件部分) + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + print('--------loss------------:', loss) + + # 记录额外信息 + wandb.log({ + "train_loss": loss.item(), + "timestep": timestep.item(), + "condition_frames_compressed": condition_frames, + "target_frames_compressed": target_frames, + "condition_frames_original": original_condition_frames, + "target_frames_original": original_target_frames, + "has_clean_latents": clean_latents is not None, + "has_clean_latents_2x": clean_latents_2x is not None, + "has_clean_latents_4x": clean_latents_4x is not None, + "total_frames_compressed": target_frames, + "total_frames_original": original_target_frames, + "dataset_name": dataset_name, # 🔧 新增:记录数据集名称 + "scene_name": scene_name, # 🔧 新增:记录场景名称 + "global_step": self.global_step + }) + + return loss + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + def on_save_checkpoint(self, checkpoint): + checkpoint_dir = "/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack_4" + 
os.makedirs(checkpoint_dir, exist_ok=True) + + current_step = self.global_step + checkpoint.clear() + + state_dict = self.pipe.denoising_model().state_dict() + torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}_framepack.ckpt")) + print(f"Saved FramePack model checkpoint: step{current_step}_framepack.ckpt") + +def train_dynamic(args): + """训练支持FramePack机制的动态历史长度模型 - 支持多数据集""" + # 🔧 修改:支持多个数据集路径 + dataset_paths = [ + "/share_zhuyixuan05/zhuyixuan05/sekai-game-drone", + "/share_zhuyixuan05/zhuyixuan05/sekai-game-walking" + ] + + dataset = DynamicSekaiDataset( + dataset_paths, # 🔧 传入多个数据集路径 + steps_per_epoch=args.steps_per_epoch, + min_condition_frames=args.min_condition_frames, + max_condition_frames=args.max_condition_frames, + target_frames=args.target_frames, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + + model = DynamicLightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + ) + + wandb.init( + project="sekai-multi-dataset-framepack-recam", # 🔧 修改项目名称 + name=f"multi-dataset-framepack-{args.min_condition_frames}-{args.max_condition_frames}", + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[], + ) + trainer.fit(model, dataloader) + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description="Train FramePack Dynamic ReCamMaster with Multiple Datasets") + # 🔧 修改:dataset_path参数现在在代码中硬编码,但保留以便兼容 + parser.add_argument("--dataset_path", type=str, + 
default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking", + help="主数据集路径(实际会使用代码中的多数据集配置)") + parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors") + parser.add_argument("--output_path", type=str, default="./") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--steps_per_epoch", type=int, default=8000) + parser.add_argument("--max_epochs", type=int, default=30) + parser.add_argument("--min_condition_frames", type=int, default=8, help="最小条件帧数") + parser.add_argument("--max_condition_frames", type=int, default=120, help="最大条件帧数") + parser.add_argument("--target_frames", type=int, default=32, help="目标帧数") + parser.add_argument("--dataloader_num_workers", type=int, default=4) + parser.add_argument("--accumulate_grad_batches", type=int, default=1) + parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1") + parser.add_argument("--use_gradient_checkpointing", action="store_true") + parser.add_argument("--use_gradient_checkpointing_offload", action="store_true") + parser.add_argument("--resume_ckpt_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack_4/step34290_framepack.ckpt") + + args = parser.parse_args() + + print("🔧 使用多数据集训练:") + print(" - /share_zhuyixuan05/zhuyixuan05/sekai-game-drone") + print(" - /share_zhuyixuan05/zhuyixuan05/sekai-game-walking") + + train_dynamic(args) \ No newline at end of file diff --git a/scripts/train_sekai_walking.py b/scripts/train_sekai_walking.py new file mode 100644 index 0000000000000000000000000000000000000000..156f383fcc3be98ef55cb265724a2e2945a2ae1a --- /dev/null +++ b/scripts/train_sekai_walking.py @@ -0,0 +1,583 @@ +import torch +import torch.nn as nn +import lightning as pl +import wandb +import os +import copy +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import os +import json +import torch +import numpy as np +from PIL import 
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute the pose of camera B relative to camera A.

    Args:
        pose_a: 4x4 extrinsic of camera A (numpy array or torch tensor),
            the world -> camera-A transform.
        pose_b: 4x4 extrinsic of camera B (numpy array or torch tensor),
            the world -> camera-B transform.
        use_torch: compute with PyTorch instead of NumPy.

    Returns:
        4x4 relative pose mapping camera-A coordinates to camera-B
        coordinates: ``pose_b @ inverse(pose_a)``.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Coerce numpy inputs to float tensors before inverting.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
+ + # 预处理设置 + # self.frame_process = v2.Compose([ + # v2.CenterCrop(size=(height, width)), + # v2.Resize(size=(height, width), antialias=True), + # v2.ToTensor(), + # v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + # ]) + + def calculate_relative_rotation(self, current_rotation, reference_rotation): + """计算相对旋转四元数""" + q_current = torch.tensor(current_rotation, dtype=torch.float32) + q_ref = torch.tensor(reference_rotation, dtype=torch.float32) + + q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]]) + + w1, x1, y1, z1 = q_ref_inv + w2, x2, y2, z2 = q_current + + relative_rotation = torch.tensor([ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + ]) + + return relative_rotation + + def select_dynamic_segment(self, full_latents): + """动态选择条件帧和目标帧 - 修正版本处理VAE时间压缩""" + total_lens = full_latents.shape[1] + # print(f"原始总帧数: {total_frames}, 压缩后: {compressed_total_frames}") + # print(f"原始关键帧: {keyframe_indices[:5]}..., 压缩后: {compressed_keyframe_indices[:5]}...") + + # 随机选择条件帧长度(基于压缩后的帧数) + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + + ratio = random.random() + print('ratio:',ratio) + if ratio<0.15: + condition_frames_compressed = 1 + elif 0.15<=ratio<0.3: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if total_lens < min_required_frames: + print(f"压缩后帧数不足: {total_lens} < {min_required_frames}") + return None + + # 随机选择起始位置(基于压缩后的帧数) + max_start = total_lens - min_required_frames - 1 + start_frame_compressed = 
random.randint(0, max_start) + + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + + + # 使用条件段的最后一个关键帧作为reference + reference_keyframe_compressed = start_frame_compressed + + # 🔧 找到对应的原始关键帧索引用于pose查找 + keyframe_original_idx = [] + for compressed_idx in range(start_frame_compressed,target_end_compressed): + keyframe_original_idx.append(compressed_idx*4) + + + + return { + 'start_frame': start_frame_compressed, # 压缩后的起始帧 + 'condition_frames': condition_frames_compressed, # 压缩后的条件帧数 + 'target_frames': target_frames_compressed, # 压缩后的目标帧数 + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + 'keyframe_original_idx': keyframe_original_idx, # 原始关键帧索引 + + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, # 用于记录 + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + + + def create_pose_embeddings(self, cam_data, segment_info): + """创建pose embeddings - 修正版本,确保与latent帧数对齐""" + cam_data_seq = cam_data['extrinsic'] # 300 * 4 * 4 + keyframe_original_idx = segment_info['keyframe_original_idx'] + # target_keyframe_indices = segment_info['target_keyframe_indices'] + + start_frame = segment_info['start_frame'] * self.time_compression_ratio + end_frame = segment_info['target_range'][1] * self.time_compression_ratio + # frame_range = cam_data_seq[start_frame:end_frame] + + relative_cams = [] + for idx in keyframe_original_idx: + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx+4] + # print('cam_prev:',cam_prev) + # print('idx:',idx) + # assert False + relative_cam = compute_relative_pose(cam_prev,cam_next) + # print('relative_cam:',relative_cam) + # assert False + relative_cams.append(torch.as_tensor(relative_cam[:3,:])) + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = 
rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + # print(pose_embedding.shape) + # assert False + # print() + # traj_pos_coord_full, tarj_pos_angle_full = get_traj_position_change(cam_data_seq, self.time_compression_ratio) + # traj_rot_angle_full = get_traj_rotation_change(cam_data_seq, self.time_compression_ratio) + + # motion_emb = + + return { + 'camera': pose_embedding + } + + def __getitem__(self, index): + while True: + try: + # 随机选择一个场景 + scene_dir = random.choice(self.scene_dirs) + + # 加载场景信息 + # with open(os.path.join(scene_dir, "scene_info.json"), 'r') as f: + # scene_info = json.load(f) + + # 加载编码的视频数据 + encoded_data = torch.load( + os.path.join(scene_dir, "encoded_video.pth"), + weights_only=False, + map_location="cpu" + ) + + # 🔧 验证latent帧数是否符合预期 + full_latents = encoded_data['latents'] # [C, T, H, W] + cam_data = encoded_data['cam_emb'] + # expected_latent_frames = scene_info['total_frames'] // self.time_compression_ratio + actual_latent_frames = full_latents.shape[1] + + # print(f"场景 {os.path.basename(scene_dir)}: 原始帧数={scene_info['total_frames']}, " + # f"预期latent帧数={expected_latent_frames}, 实际latent帧数={actual_latent_frames}") + + # if abs(actual_latent_frames - expected_latent_frames) > 2: # 允许小的舍入误差 + # print(f"⚠️ Latent帧数不匹配,跳过此样本") + # continue + + # 动态选择段落 + segment_info = self.select_dynamic_segment(full_latents) + if segment_info is None: + continue + # print("segment_info:",segment_info) + # 创建pose embeddings + pose_data = self.create_pose_embeddings(cam_data, segment_info) + if pose_data is None: + continue + + n = segment_info["condition_frames"] + m = segment_info['target_frames'] + + + mask = torch.zeros(n+m, dtype=torch.float32) + mask[:n] = 1.0 + mask = mask.view(-1, 1) + + + pose_data["camera"] = torch.cat([pose_data["camera"], mask], dim=1) + # print(pose_data['camera'].shape) + # assert False + # 🔧 使用压缩后的索引提取latent段落 + start_frame = segment_info['start_frame'] # 已经是压缩后的索引 + 
condition_frames = segment_info['condition_frames'] # 已经是压缩后的帧数 + target_frames = segment_info['target_frames'] # 已经是压缩后的帧数 + + # print(f"提取latent段落: start={start_frame}, condition={condition_frames}, target={target_frames}") + # print(f"Full latents shape: {full_latents.shape}") + + # # 确保索引不越界 + # if start_frame + condition_frames + target_frames > full_latents.shape[1]: + # print(f"索引越界,跳过: {start_frame + condition_frames + target_frames} > {full_latents.shape[1]}") + # continue + + condition_latents = full_latents[:, start_frame:start_frame+condition_frames, :, :] + target_latents = full_latents[:, start_frame+condition_frames:start_frame+condition_frames+target_frames, :, :] + + # print(f"Condition latents shape: {condition_latents.shape}") + # print(f"Target latents shape: {target_latents.shape}") + + # 拼接latents [condition, target] + combined_latents = torch.cat([condition_latents, target_latents], dim=1) + + result = { + "latents": combined_latents, + "prompt_emb": encoded_data["prompt_emb"], + "image_emb": encoded_data.get("image_emb", {}), + "camera": pose_data['camera'], + + "condition_frames": condition_frames, # 压缩后的帧数 + "target_frames": target_frames, # 压缩后的帧数 + "scene_name": os.path.basename(scene_dir), + # 🔧 新增:记录原始帧数用于调试 + "original_condition_frames": segment_info['original_condition_frames'], + "original_target_frames": segment_info['original_target_frames'], + } + + return result + + except Exception as e: + print(f"Error loading sample: {e}") + import traceback + traceback.print_exc() + continue + + def __len__(self): + return self.steps_per_epoch + +class DynamicLightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + resume_ckpt_path=None + ): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + 
dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + # 添加相机编码器 + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + block.cam_encoder = nn.Linear(13 , dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = torch.load(resume_ckpt_path, map_location="cpu") + self.pipe.dit.load_state_dict(state_dict, strict=True) + print('load checkpoint:', resume_ckpt_path) + + self.freeze_parameters() + + # 只训练相机相关和注意力模块 + for name, module in self.pipe.denoising_model().named_modules(): + if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn"]): + for param in module.parameters(): + param.requires_grad = True + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + + # 创建可视化目录 + self.vis_dir = "sekai_dynamic/visualizations_dynamic" + os.makedirs(self.vis_dir, exist_ok=True) + + def freeze_parameters(self): + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + def training_step(self, batch, batch_idx): + # 获取动态长度信息(这些已经是压缩后的帧数) + condition_frames = batch["condition_frames"][0].item() # 压缩后的condition长度 + target_frames = batch["target_frames"][0].item() # 压缩后的target长度 + + # 🔧 获取原始帧数用于日志记录 + original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0] + original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0] + + # Data + latents = 
batch["latents"].to(self.device) + # print(f"压缩后condition帧数: {condition_frames}, target帧数: {target_frames}") + # print(f"原始condition帧数: {original_condition_frames}, target帧数: {original_target_frames}") + # print(f"Latents shape: {latents.shape}") + + # 裁剪空间尺寸以节省内存 + # target_height, target_width = 50, 70 + # current_height, current_width = latents.shape[3], latents.shape[4] + + # if current_height > target_height or current_width > target_width: + # h_start = (current_height - target_height) // 2 + # w_start = (current_width - target_width) // 2 + # latents = latents[:, :, :, + # h_start:h_start+target_height, + # w_start:w_start+target_width] + + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + # print(f"裁剪后latents shape: {latents.shape}") + + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + cam_emb = batch["camera"].to(self.device) + + # Loss计算 + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + + extra_input = self.pipe.prepare_extra_input(latents) + origin_latents = copy.deepcopy(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + # 🔧 关键:使用压缩后的condition长度 + # condition部分保持clean,只对target部分加噪 + noisy_latents[:, :, :condition_frames, ...] = origin_latents[:, :, :condition_frames, ...] 
+ training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + # print(f"targe尺寸: {training_target.shape}") + # 预测噪声 + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + # print(f"pred尺寸: {training_target.shape}") + # 🔧 只对target部分计算loss(使用压缩后的索引) + target_noise_pred = noise_pred[:, :, condition_frames:condition_frames+target_frames, ...] + target_training_target = training_target[:, :, condition_frames:condition_frames+target_frames, ...] + + loss = torch.nn.functional.mse_loss(target_noise_pred.float(), target_training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + print('--------loss------------:',loss) + + # 记录额外信息 + wandb.log({ + "train_loss": loss.item(), + "timestep": timestep.item(), + "condition_frames_compressed": condition_frames, # 压缩后的帧数000 + "target_frames_compressed": target_frames, + "condition_frames_original": original_condition_frames, # 原始帧数 + "target_frames_original": original_target_frames, + "total_frames_compressed": condition_frames + target_frames, + "total_frames_original": original_condition_frames + original_target_frames, + "global_step": self.global_step + }) + + return loss + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + def on_save_checkpoint(self, checkpoint): + checkpoint_dir = "/home/zhuyixuan05/ReCamMaster/sekai_walking" + os.makedirs(checkpoint_dir, exist_ok=True) + + current_step = self.global_step + checkpoint.clear() + + state_dict = self.pipe.denoising_model().state_dict() + torch.save(state_dict, os.path.join(checkpoint_dir, f"step{current_step}_dynamic.ckpt")) + 
print(f"Saved dynamic model checkpoint: step{current_step}_dynamic.ckpt") + +def train_dynamic(args): + """训练支持动态历史长度的模型""" + dataset = DynamicSekaiDataset( + args.dataset_path, + steps_per_epoch=args.steps_per_epoch, + min_condition_frames=args.min_condition_frames, + max_condition_frames=args.max_condition_frames, + target_frames=args.target_frames, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + + model = DynamicLightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + resume_ckpt_path=args.resume_ckpt_path, + ) + + wandb.init( + project="nuscenes-dynamic-recam", + name=f"dynamic-{args.min_condition_frames}-{args.max_condition_frames}", + ) + + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[], + ) + trainer.fit(model, dataloader) + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description="Train Dynamic ReCamMaster") + parser.add_argument("--dataset_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking") + parser.add_argument("--dit_path", type=str, default="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors") + parser.add_argument("--output_path", type=str, default="./") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--steps_per_epoch", type=int, default=8000) + parser.add_argument("--max_epochs", type=int, default=30) + parser.add_argument("--min_condition_frames", type=int, default=10, help="最小条件帧数") + parser.add_argument("--max_condition_frames", type=int, default=40, help="最大条件帧数") + 
parser.add_argument("--target_frames", type=int, default=32, help="目标帧数") + parser.add_argument("--dataloader_num_workers", type=int, default=4) + parser.add_argument("--accumulate_grad_batches", type=int, default=1) + parser.add_argument("--training_strategy", type=str, default="deepspeed_stage_1") + parser.add_argument("--use_gradient_checkpointing", action="store_true") + parser.add_argument("--use_gradient_checkpointing_offload", action="store_true") + parser.add_argument("--resume_ckpt_path", type=str, default=None) + + args = parser.parse_args() + + train_dynamic(args) \ No newline at end of file diff --git a/scripts/train_sekai_walking_noise.py b/scripts/train_sekai_walking_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..26192b73474a58dbb3817bc6fb7702b920e90634 --- /dev/null +++ b/scripts/train_sekai_walking_noise.py @@ -0,0 +1,596 @@ +import torch +import torch.nn as nn +import lightning as pl +import wandb +import os +import copy +from diffsynth import WanVideoReCamMasterPipeline, ModelManager +import os +import json +import torch +import numpy as np +from PIL import Image +import imageio +import random +from torchvision.transforms import v2 +from einops import rearrange +from pose_classifier import PoseClassifier +from scipy.spatial.transform import Rotation as R + +import pdb +# cam_c2w, [N * 4 * 4] +# stride, frame stride +def compute_relative_pose_matrix(pose1, pose2): + """ + 计算相邻两帧的相对位姿,返回3×4的相机矩阵 [R_rel | t_rel] + + 参数: + pose1: 第i帧的相机位姿,形状为(7,)的数组 [tx1, ty1, tz1, qx1, qy1, qz1, qw1] + pose2: 第i+1帧的相机位姿,形状为(7,)的数组 [tx2, ty2, tz2, qx2, qy2, qz2, qw2] + + 返回: + relative_matrix: 3×4的相对位姿矩阵,前3列是旋转矩阵R_rel,第4列是平移向量t_rel + """ + # 分离平移向量和四元数 + t1 = pose1[:3] # 第i帧平移 [tx1, ty1, tz1] + q1 = pose1[3:] # 第i帧四元数 [qx1, qy1, qz1, qw1] + t2 = pose2[:3] # 第i+1帧平移 + q2 = pose2[3:] # 第i+1帧四元数 + + # 1. 
计算相对旋转矩阵 R_rel + rot1 = R.from_quat(q1) # 第i帧旋转 + rot2 = R.from_quat(q2) # 第i+1帧旋转 + rot_rel = rot2 * rot1.inv() # 相对旋转 = 后一帧旋转 × 前一帧旋转的逆 + R_rel = rot_rel.as_matrix() # 转换为3×3矩阵 + + # 2. 计算相对平移向量 t_rel + R1_T = rot1.as_matrix().T # 前一帧旋转矩阵的转置(等价于逆) + t_rel = R1_T @ (t2 - t1) # 相对平移 = R1^T × (t2 - t1) + + # 3. 组合为3×4矩阵 [R_rel | t_rel] + relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)]) + + return relative_matrix + +def compute_relative_pose(pose_a, pose_b, use_torch=False): + """ + 计算相机B相对于相机A的相对位姿矩阵 + + 参数: + pose_a: 相机A的外参矩阵 (4x4),可以是numpy数组或PyTorch张量 + 表示从世界坐标系到相机A坐标系的变换 (world → camera A) + pose_b: 相机B的外参矩阵 (4x4),可以是numpy数组或PyTorch张量 + 表示从世界坐标系到相机B坐标系的变换 (world → camera B) + use_torch: 是否使用PyTorch进行计算,默认使用NumPy + + 返回: + relative_pose: 相对位姿矩阵 (4x4),表示从相机A坐标系到相机B坐标系的变换 + (camera A → camera B) + """ + # 检查输入形状 + assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}" + assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}" + + if use_torch: + # 确保输入是PyTorch张量 + if not isinstance(pose_a, torch.Tensor): + pose_a = torch.from_numpy(pose_a).float() + if not isinstance(pose_b, torch.Tensor): + pose_b = torch.from_numpy(pose_b).float() + + # 计算相对位姿: relative_pose = pose_b × inverse(pose_a) + pose_a_inv = torch.inverse(pose_a) + relative_pose = torch.matmul(pose_b, pose_a_inv) + else: + # 确保输入是NumPy数组 + if not isinstance(pose_a, np.ndarray): + pose_a = np.array(pose_a, dtype=np.float32) + if not isinstance(pose_b, np.ndarray): + pose_b = np.array(pose_b, dtype=np.float32) + + # 计算相对位姿: relative_pose = pose_b × inverse(pose_a) + pose_a_inv = np.linalg.inv(pose_a) + relative_pose = np.matmul(pose_b, pose_a_inv) + + return relative_pose + + +class DynamicSekaiDataset(torch.utils.data.Dataset): + """支持动态历史长度的NuScenes数据集""" + + def __init__(self, base_path, steps_per_epoch, + min_condition_frames=10, max_condition_frames=40, + target_frames=10, height=900, width=1600): + self.base_path = base_path + self.scenes_path = base_path + 
self.min_condition_frames = min_condition_frames + self.max_condition_frames = max_condition_frames + self.target_frames = target_frames + self.height = height + self.width = width + self.steps_per_epoch = steps_per_epoch + self.pose_classifier = PoseClassifier() + + # 🔧 新增:VAE时间压缩比例 + self.time_compression_ratio = 4 # VAE将时间维度压缩4倍 + + # 查找所有处理好的场景 + self.scene_dirs = [] + if os.path.exists(self.scenes_path): + + for item in os.listdir(self.scenes_path): + scene_dir = os.path.join(self.scenes_path, item) + if os.path.isdir(scene_dir): + + encoded_path = os.path.join(scene_dir, "encoded_video.pth") + if os.path.exists(encoded_path): + self.scene_dirs.append(scene_dir) + # print(f"Found {len(self.scene_dirs)} scenes with encoded data") + assert len(self.scene_dirs) > 0, "No encoded scenes found!" + + # 预处理设置 + # self.frame_process = v2.Compose([ + # v2.CenterCrop(size=(height, width)), + # v2.Resize(size=(height, width), antialias=True), + # v2.ToTensor(), + # v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + # ]) + + def calculate_relative_rotation(self, current_rotation, reference_rotation): + """计算相对旋转四元数""" + q_current = torch.tensor(current_rotation, dtype=torch.float32) + q_ref = torch.tensor(reference_rotation, dtype=torch.float32) + + q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]]) + + w1, x1, y1, z1 = q_ref_inv + w2, x2, y2, z2 = q_current + + relative_rotation = torch.tensor([ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + ]) + + return relative_rotation + + def select_dynamic_segment(self, full_latents): + """动态选择条件帧和目标帧 - 修正版本处理VAE时间压缩""" + total_lens = full_latents.shape[1] + # print(f"原始总帧数: {total_frames}, 压缩后: {compressed_total_frames}") + # print(f"原始关键帧: {keyframe_indices[:5]}..., 压缩后: {compressed_keyframe_indices[:5]}...") + + # 随机选择条件帧长度(基于压缩后的帧数) + + min_condition_compressed = 
self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + + target_frames_compressed = self.target_frames // self.time_compression_ratio + max_condition_compressed = min(max_condition_compressed,total_lens - target_frames_compressed) + # min_condition_compressed = min() + + ratio = random.random() + print('ratio:',ratio) + if ratio<0.15: + condition_frames_compressed = 1 + elif 0.15<=ratio<0.3: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if total_lens < min_required_frames: + print(f"压缩后帧数不足: {total_lens} < {min_required_frames}") + return None + + # 随机选择起始位置(基于压缩后的帧数) + max_start = total_lens - min_required_frames - 1 + start_frame_compressed = random.randint(0, max_start) + + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + + + # 使用条件段的最后一个关键帧作为reference + reference_keyframe_compressed = start_frame_compressed + + # 🔧 找到对应的原始关键帧索引用于pose查找 + keyframe_original_idx = [] + for compressed_idx in range(start_frame_compressed,target_end_compressed): + keyframe_original_idx.append(compressed_idx) + + + + return { + 'start_frame': start_frame_compressed, # 压缩后的起始帧 + 'condition_frames': condition_frames_compressed, # 压缩后的条件帧数 + 'target_frames': target_frames_compressed, # 压缩后的目标帧数 + 'condition_range': (start_frame_compressed, condition_end_compressed), + 'target_range': (condition_end_compressed, target_end_compressed), + 'keyframe_original_idx': keyframe_original_idx, # 原始关键帧索引 + + 'original_condition_frames': condition_frames_compressed * self.time_compression_ratio, # 用于记录 + 'original_target_frames': target_frames_compressed * self.time_compression_ratio, + } + 
+ + def create_pose_embeddings(self, cam_data, segment_info): + """创建pose embeddings - 修正版本,确保与latent帧数对齐""" + cam_data_seq = cam_data['extrinsic'] # 300 * 4 * 4 + # print(cam_data_seq.shape) + keyframe_original_idx = segment_info['keyframe_original_idx'] + # target_keyframe_indices = segment_info['target_keyframe_indices'] + + start_frame = segment_info['start_frame'] * self.time_compression_ratio + end_frame = segment_info['target_range'][1] * self.time_compression_ratio + # frame_range = cam_data_seq[start_frame:end_frame] + + relative_cams = [] + for idx in keyframe_original_idx: + cam_prev = cam_data_seq[idx] + cam_next = cam_data_seq[idx+1] + # print('cam_prev:',cam_prev) + # print('idx:',idx) + # assert False + relative_cam = compute_relative_pose_matrix(cam_prev,cam_next) + # print(relative_cam) + # print('relative_cam:',relative_cam) + # assert False + relative_cams.append(torch.as_tensor(relative_cam[:3,:])) + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + # print(pose_embedding) + pose_embedding = pose_embedding.to(torch.bfloat16) + + # print(pose_embedding.shape) + # assert False + # print() + # traj_pos_coord_full, tarj_pos_angle_full = get_traj_position_change(cam_data_seq, self.time_compression_ratio) + # traj_rot_angle_full = get_traj_rotation_change(cam_data_seq, self.time_compression_ratio) + + # motion_emb = + + return { + 'camera': pose_embedding + } + + def __getitem__(self, index): + while True: + try: + # 随机选择一个场景 + scene_dir = random.choice(self.scene_dirs) + + # 加载场景信息 + # with open(os.path.join(scene_dir, "scene_info.json"), 'r') as f: + # scene_info = json.load(f) + + # 加载编码的视频数据 + encoded_data = torch.load( + os.path.join(scene_dir, "encoded_video.pth"), + weights_only=False, + map_location="cpu" + ) + + # 🔧 验证latent帧数是否符合预期 + full_latents = encoded_data['latents'] # [C, T, H, W] + cam_data = encoded_data['cam_emb'] + # expected_latent_frames = 
scene_info['total_frames'] // self.time_compression_ratio + actual_latent_frames = full_latents.shape[1] + + # print(f"场景 {os.path.basename(scene_dir)}: 原始帧数={scene_info['total_frames']}, " + # f"预期latent帧数={expected_latent_frames}, 实际latent帧数={actual_latent_frames}") + + # if abs(actual_latent_frames - expected_latent_frames) > 2: # 允许小的舍入误差 + # print(f"⚠️ Latent帧数不匹配,跳过此样本") + # continue + + # 动态选择段落 + segment_info = self.select_dynamic_segment(full_latents) + # print(segment_info) + if segment_info is None: + continue + # print("segment_info:",segment_info) + # 创建pose embeddings + pose_data = self.create_pose_embeddings(cam_data, segment_info) + if pose_data is None: + continue + + n = segment_info["condition_frames"] + m = segment_info['target_frames'] + + + mask = torch.zeros(n+m, dtype=torch.float32) + mask[:n] = 1.0 + mask = mask.view(-1, 1) + + + pose_data["camera"] = torch.cat([pose_data["camera"], mask], dim=1) + # print(pose_data['camera'].shape) + # assert False + # 🔧 使用压缩后的索引提取latent段落 + start_frame = segment_info['start_frame'] # 已经是压缩后的索引 + condition_frames = segment_info['condition_frames'] # 已经是压缩后的帧数 + target_frames = segment_info['target_frames'] # 已经是压缩后的帧数 + + # print(f"提取latent段落: start={start_frame}, condition={condition_frames}, target={target_frames}") + # print(f"Full latents shape: {full_latents.shape}") + + # # 确保索引不越界 + # if start_frame + condition_frames + target_frames > full_latents.shape[1]: + # print(f"索引越界,跳过: {start_frame + condition_frames + target_frames} > {full_latents.shape[1]}") + # continue + + condition_latents = full_latents[:, start_frame:start_frame+condition_frames, :, :] + + + + target_latents = full_latents[:, start_frame+condition_frames:start_frame+condition_frames+target_frames, :, :] + + # print(f"Condition latents shape: {condition_latents.shape}") + # print(f"Target latents shape: {target_latents.shape}") + + # 拼接latents [condition, target] + combined_latents = torch.cat([condition_latents, target_latents], 
dim=1) + # print('latent:',combined_latents.requires_grad) + # print('prompt:',encoded_data["prompt_emb"]["context"].requires_grad) + # print('camera:',pose_data['camera'].requires_grad) + result = { + "latents": combined_latents, + "prompt_emb": encoded_data["prompt_emb"], + "image_emb": encoded_data.get("image_emb", {}), + "camera": pose_data['camera'], + + "condition_frames": condition_frames, # 压缩后的帧数 + "target_frames": target_frames, # 压缩后的帧数 + "scene_name": os.path.basename(scene_dir), + # 🔧 新增:记录原始帧数用于调试 + "original_condition_frames": segment_info['original_condition_frames'], + "original_target_frames": segment_info['original_target_frames'], + } + + return result + + except Exception as e: + print(f"Error loading sample: {e}") + import traceback + traceback.print_exc() + continue + + def __len__(self): + return self.steps_per_epoch + +class DynamicLightningModelForTrain(pl.LightningModule): + def __init__( + self, + dit_path, + learning_rate=1e-5, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + resume_ckpt_path=None + ): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) + model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"]) + + self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + + # 添加相机编码器 + dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0] + for block in self.pipe.dit.blocks: + block.cam_encoder = nn.Linear(13 , dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) + + if resume_ckpt_path is not None: + state_dict = 
    def training_step(self, batch, batch_idx):
        """Run one diffusion training step on a dynamic-length sample.

        The batch holds a single pre-encoded sample (batch_size=1). The first
        `condition_frames` latent frames act as (clean or lightly noised)
        history; only the following `target_frames` frames contribute to the
        loss.

        Returns:
            Scalar MSE loss, weighted by the scheduler's timestep weight.
        """
        # Dynamic segment lengths — these are already VAE-compressed frame counts.
        condition_frames = batch["condition_frames"][0].item() # compressed condition length
        target_frames = batch["target_frames"][0].item() # compressed target length

        # Original (pre-compression) frame counts, used for logging only.
        # NOTE(review): the fallback assumes a 4x temporal VAE compression — confirm.
        original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0]
        print("condition_frames:",batch["condition_frames"])
        original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0]

        # Data
        latents = batch["latents"].to(self.device)
        # print(f"压缩后condition帧数: {condition_frames}, target帧数: {target_frames}")
        # print(f"原始condition帧数: {original_condition_frames}, target帧数: {original_target_frames}")
        # print(f"Latents shape: {latents.shape}")

        # (disabled) spatial center-crop to save memory
        # target_height, target_width = 50, 70
        # current_height, current_width = latents.shape[3], latents.shape[4]

        # if current_height > target_height or current_width > target_width:
        #     h_start = (current_height - target_height) // 2
        #     w_start = (current_width - target_width) // 2
        #     latents = latents[:, :, :,
        #                       h_start:h_start+target_height,
        #                       w_start:w_start+target_width]

        prompt_emb = batch["prompt_emb"]
        # Drop the dataloader batch dimension and move embeddings to the device.
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]
        # print(f"裁剪后latents shape: {latents.shape}")

        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        cam_emb = batch["camera"].to(self.device)

        # Loss computation.
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)

        noisy_condition_latents = copy.deepcopy(latents[:, :, :condition_frames, ...])
        is_add_noise = random.random()
        if is_add_noise > 0.2:
            # 80% of the time, perturb the condition frames with noise drawn at
            # a (capped) random timestep so the model tolerates imperfect history.
            noise_cond = torch.randn_like(latents[:, :, :condition_frames, ...])
            timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps//4*3, (1,))
            timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
            noisy_condition_latents = self.pipe.scheduler.add_noise(latents[:, :, :condition_frames, ...], noise_cond, timestep_cond)

        extra_input = self.pipe.prepare_extra_input(latents)
        origin_latents = copy.deepcopy(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)

        # Key step (compressed-frame indexing): overwrite the condition span so
        # only the target span carries the full training noise.
        noisy_latents[:, :, :condition_frames, ...] = noisy_condition_latents #origin_latents[:, :, :condition_frames, ...]
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
        # print(f"targe尺寸: {training_target.shape}")
        # Predict the noise.
        noise_pred = self.pipe.denoising_model()(
            noisy_latents, timestep=timestep, cam_emb=cam_emb, **prompt_emb, **extra_input, **image_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )
        # print(f"pred尺寸: {training_target.shape}")
        # Compute the loss on the target span only (compressed indices).
        target_noise_pred = noise_pred[:, :, condition_frames:condition_frames+target_frames, ...]
        target_training_target = training_target[:, :, condition_frames:condition_frames+target_frames, ...]

        loss = torch.nn.functional.mse_loss(target_noise_pred.float(), target_training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)
        print('--------loss------------:',loss)

        # Log step metrics (both compressed and original frame counts).
        wandb.log({
            "train_loss": loss.item(),
            "timestep": timestep.item(),
            "condition_frames_compressed": condition_frames,
            "target_frames_compressed": target_frames,
            "condition_frames_original": original_condition_frames,
            "target_frames_original": original_target_frames,
            "total_frames_compressed": condition_frames + target_frames,
            "total_frames_original": original_condition_frames + original_target_frames,
            "global_step": self.global_step
        })

        return loss
def train_dynamic(args):
    """Train the dynamic-history-length ReCamMaster model.

    Wires up the dataset, dataloader, Lightning module, wandb run and trainer
    from the parsed CLI `args`, then runs the fit loop.
    """
    train_dataset = DynamicSekaiDataset(
        args.dataset_path,
        steps_per_epoch=args.steps_per_epoch,
        min_condition_frames=args.min_condition_frames,
        max_condition_frames=args.max_condition_frames,
        target_frames=args.target_frames,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers,
    )

    lit_model = DynamicLightningModelForTrain(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        resume_ckpt_path=args.resume_ckpt_path,
    )

    wandb.init(
        project="nuscenes-dynamic-recam",
        name=f"dynamic-{args.min_condition_frames}-{args.max_condition_frames}",
    )

    pl_trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[],
    )
    pl_trainer.fit(lit_model, train_loader)
def compute_relative_pose_matrix(pose1, pose2):
    """Return the 3x4 relative camera matrix [R_rel | t_rel] between two frames.

    Args:
        pose1: pose of frame i, a length-7 array [tx, ty, tz, qx, qy, qz, qw].
        pose2: pose of frame i+1, same layout.

    Returns:
        A (3, 4) ndarray whose left 3x3 block is the relative rotation and
        whose last column is the relative translation.
    """
    # Split each pose into a translation vector and a quaternion (scipy order).
    trans_prev, quat_prev = pose1[:3], pose1[3:]
    trans_next, quat_next = pose2[:3], pose2[3:]

    rot_prev = R.from_quat(quat_prev)
    rot_next = R.from_quat(quat_next)

    # Relative rotation: next frame's rotation composed with the inverse of
    # the previous frame's rotation (R2 · R1⁻¹).
    rel_rotation = (rot_next * rot_prev.inv()).as_matrix()

    # Relative translation expressed in the previous frame's coordinates:
    # R1ᵀ · (t2 − t1).
    # NOTE(review): the rotation uses R2·R1⁻¹ while the translation uses the
    # R1⁻¹·T2 convention (whose rotation would be R1ᵀ·R2) — mixed conventions;
    # confirm this is intentional before reusing elsewhere.
    rel_translation = rot_prev.as_matrix().T @ (trans_next - trans_prev)

    # Assemble the 3x4 matrix [R_rel | t_rel].
    return np.hstack([rel_rotation, rel_translation.reshape(3, 1)])
+ + def select_dynamic_segment_framepack(self, full_latents): + """🔧 FramePack风格的动态选择条件帧和目标帧 - SpatialVid版本""" + total_lens = full_latents.shape[1] + + min_condition_compressed = self.min_condition_frames // self.time_compression_ratio + max_condition_compressed = self.max_condition_frames // self.time_compression_ratio + target_frames_compressed = self.target_frames // self.time_compression_ratio + max_condition_compressed = min(max_condition_compressed, total_lens - target_frames_compressed) + + ratio = random.random() + #print('ratio:', ratio) + if ratio < 0.15: + condition_frames_compressed = 1 + elif 0.15 <= ratio < 0.9: + condition_frames_compressed = random.randint(min_condition_compressed, max_condition_compressed) + else: + condition_frames_compressed = target_frames_compressed + + # 确保有足够的帧数 + min_required_frames = condition_frames_compressed + target_frames_compressed + if total_lens < min_required_frames: + print(f"压缩后帧数不足: {total_lens} < {min_required_frames}") + return None + + # 随机选择起始位置(基于压缩后的帧数) + max_start = total_lens - min_required_frames - 1 + start_frame_compressed = random.randint(0, max_start) + + condition_end_compressed = start_frame_compressed + condition_frames_compressed + target_end_compressed = condition_end_compressed + target_frames_compressed + + # 🔧 FramePack风格的索引处理 + latent_indices = torch.arange(condition_end_compressed, target_end_compressed) # 只预测未来帧 + + # 🔧 根据实际的condition_frames_compressed生成索引 + # 1x帧:起始帧 + 最后1帧 + clean_latent_indices_start = torch.tensor([start_frame_compressed]) + clean_latent_1x_indices = torch.tensor([condition_end_compressed - 1]) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices]) + + # 🔧 2x帧:根据实际condition长度确定 + if condition_frames_compressed >= 2: + # 取最后2帧(如果有的话) + clean_latent_2x_start = max(start_frame_compressed, condition_end_compressed - 2) + clean_latent_2x_indices = torch.arange(clean_latent_2x_start-1, condition_end_compressed-1) + else: + # 
    def create_pose_embeddings(self, cam_data, segment_info):
        """Build SpatialVid-style per-frame relative-pose embeddings.

        Unlike the 4-frame-stride variants, SpatialVid computes the relative
        pose between consecutive frames (stride 1) for every keyframe in the
        selected segment.

        Args:
            cam_data: dict with key 'extrinsic' holding the per-frame camera data.
            segment_info: output of select_dynamic_segment_framepack.

        Returns:
            bfloat16 tensor of shape [num_frames, 12] (flattened 3x4 relative
            poses), or None when no frames were produced.
        """
        # NOTE(review): the original comment claims 'extrinsic' is N * 4 * 4,
        # but compute_relative_pose_matrix slices each entry as a length-7
        # [tx,ty,tz,qx,qy,qz,qw] vector — these are inconsistent; confirm the
        # actual storage format upstream.
        cam_data_seq = cam_data['extrinsic'] # per original comment: N * 4 * 4

        # Relative camera motion for every (condition + target) keyframe.
        keyframe_original_idx = segment_info['keyframe_original_idx']

        relative_cams = []
        for idx in keyframe_original_idx:
            if idx + 1 < len(cam_data_seq):
                cam_prev = cam_data_seq[idx]
                cam_next = cam_data_seq[idx + 1] # SpatialVid: 1-frame stride
                relative_cam = compute_relative_pose_matrix(cam_prev, cam_next)
                relative_cams.append(torch.as_tensor(relative_cam[:3, :]))
            else:
                # No next frame: fall back to zero motion (all-zeros 3x4).
                identity_cam = torch.zeros(3, 4)
                relative_cams.append(identity_cam)

        if len(relative_cams) == 0:
            return None

        # Flatten each 3x4 pose into a 12-dim vector per frame.
        pose_embedding = torch.stack(relative_cams, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
        pose_embedding = pose_embedding.to(torch.bfloat16)

        return pose_embedding
如果没有下一帧,使用零运动 + identity_cam = torch.zeros(3, 4) + relative_cams.append(identity_cam) + + if len(relative_cams) == 0: + return None + + pose_embedding = torch.stack(relative_cams, dim=0) + pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') + pose_embedding = pose_embedding.to(torch.bfloat16) + + return pose_embedding + + def prepare_framepack_inputs(self, full_latents, segment_info): + """🔧 准备FramePack风格的多尺度输入 - SpatialVid版本""" + # 🔧 修正:处理4维输入 [C, T, H, W],添加batch维度 + if len(full_latents.shape) == 4: + full_latents = full_latents.unsqueeze(0) # [C, T, H, W] -> [1, C, T, H, W] + B, C, T, H, W = full_latents.shape + else: + B, C, T, H, W = full_latents.shape + + # 主要latents(用于去噪预测) + latent_indices = segment_info['latent_indices'] + main_latents = full_latents[:, :, latent_indices, :, :] + + # 🔧 1x条件帧(起始帧 + 最后1帧) + clean_latent_indices = segment_info['clean_latent_indices'] + clean_latents = full_latents[:, :, clean_latent_indices, :, :] + + # 🔧 4x条件帧 - 总是16帧,直接用真实索引 + 0填充 + clean_latent_4x_indices = segment_info['clean_latent_4x_indices'] + + # 创建固定长度16的latents,初始化为0 + clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype) + clean_latent_4x_indices_final = torch.full((16,), -1, dtype=torch.long) # -1表示padding + + # 🔧 修正:检查是否有有效的4x索引 + if len(clean_latent_4x_indices) > 0: + actual_4x_frames = len(clean_latent_4x_indices) + # 从后往前填充,确保最新的帧在最后 + start_pos = max(0, 16 - actual_4x_frames) + end_pos = 16 + actual_start = max(0, actual_4x_frames - 16) # 如果超过16帧,只取最后16帧 + + clean_latents_4x[:, :, start_pos:end_pos, :, :] = full_latents[:, :, clean_latent_4x_indices[actual_start:], :, :] + clean_latent_4x_indices_final[start_pos:end_pos] = clean_latent_4x_indices[actual_start:] + + # 🔧 2x条件帧 - 总是2帧,直接用真实索引 + 0填充 + clean_latent_2x_indices = segment_info['clean_latent_2x_indices'] + + # 创建固定长度2的latents,初始化为0 + clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype) + clean_latent_2x_indices_final = torch.full((2,), -1, 
    def __getitem__(self, index):
        """Sample one training example; retries until a valid sample is built.

        `index` is ignored: a scene is drawn at random on every call, so the
        dataset is effectively an infinite random sampler whose epoch length
        is fixed by __len__.
        """
        while True:
            try:
                # Pick a random pre-encoded scene directory.
                scene_dir = random.choice(self.scene_dirs)

                # Load cached VAE latents + embeddings (trusted local file;
                # weights_only=False is required for the pickled dict payload).
                encoded_data = torch.load(
                    os.path.join(scene_dir, "encoded_video.pth"),
                    weights_only=False,
                    map_location="cpu"
                )

                # Latents are [C, T, H, W]; T is the compressed time axis.
                full_latents = encoded_data['latents'] # [C, T, H, W]
                cam_data = encoded_data['cam_emb']
                actual_latent_frames = full_latents.shape[1]

                # Randomly choose condition/target windows; None -> retry.
                segment_info = self.select_dynamic_segment_framepack(full_latents)
                if segment_info is None:
                    continue

                # Per-frame relative camera poses (SpatialVid: 1-frame stride).
                all_camera_embeddings = self.create_pose_embeddings(cam_data, segment_info)
                if all_camera_embeddings is None:
                    continue

                # FramePack-style multi-scale latent inputs (1x/2x/4x context).
                framepack_inputs = self.prepare_framepack_inputs(full_latents, segment_info)

                n = segment_info["condition_frames"]
                m = segment_info['target_frames']

                # Mask channel: 1.0 marks condition frames, 0.0 target frames.
                mask = torch.zeros(n+m, dtype=torch.float32)
                mask[:n] = 1.0 # condition frames flagged with 1
                mask = mask.view(-1, 1)

                # Append the mask as an extra column of the camera embeddings.
                camera_with_mask = torch.cat([all_camera_embeddings, mask], dim=1)

                result = {
                    # FramePack-style multi-scale inputs.
                    "latents": framepack_inputs['latents'], # denoising targets
                    "clean_latents": framepack_inputs['clean_latents'], # condition frames
                    "clean_latents_2x": framepack_inputs['clean_latents_2x'],
                    "clean_latents_4x": framepack_inputs['clean_latents_4x'],
                    "latent_indices": framepack_inputs['latent_indices'],
                    "clean_latent_indices": framepack_inputs['clean_latent_indices'],
                    "clean_latent_2x_indices": framepack_inputs['clean_latent_2x_indices'],
                    "clean_latent_4x_indices": framepack_inputs['clean_latent_4x_indices'],

                    # Camera embeddings for all frames, with the mask column.
                    "camera": camera_with_mask,

                    "prompt_emb": encoded_data["prompt_emb"],
                    "image_emb": encoded_data.get("image_emb", {}),

                    "condition_frames": n, # compressed frame count
                    "target_frames": m, # compressed frame count
                    "scene_name": os.path.basename(scene_dir),
                    "dataset_name": "spatialvid",
                    # Original (pre-compression) frame counts, for debugging.
                    "original_condition_frames": segment_info['original_condition_frames'],
                    "original_target_frames": segment_info['original_target_frames'],
                }

                return result

            except Exception as e:
                # Best-effort sampler: log the failure and retry with a new scene.
                print(f"Error loading sample: {e}")
                traceback.print_exc()
                continue

    def __len__(self):
        # One "epoch" is a fixed number of random draws.
        return self.steps_per_epoch
class SpatialVidFramePackLightningModel(pl.LightningModule):
    """Lightning module training a FramePack-enabled ReCamMaster DiT on SpatialVid."""

    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        resume_ckpt_path=None
    ):
        """Load the base pipeline, graft FramePack + camera modules, and set
        up which parameters are trainable.

        Args:
            dit_path: path to the DiT weights file, or a comma-separated list.
            learning_rate: AdamW learning rate.
            use_gradient_checkpointing: forward with activation checkpointing.
            use_gradient_checkpointing_offload: offload checkpointed activations.
            resume_ckpt_path: optional checkpoint; cam_encoder weights are
                deliberately NOT restored from it.
        """
        super().__init__()
        # Swap the registered DiT class for WanModelFuture before loading.
        replace_dit_model_in_manager()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        if os.path.isfile(dit_path):
            model_manager.load_models([dit_path])
        else:
            # Comma-separated list of weight files.
            dit_path = dit_path.split(",")
            model_manager.load_models([dit_path])
        model_manager.load_models(["models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"])

        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)

        # Graft the FramePack clean_x_embedder onto the DiT.
        self.add_framepack_components()

        # Add a per-block camera encoder (12-dim pose + 1 mask channel = 13)
        # and an identity-initialised projector.
        dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
        for block in self.pipe.dit.blocks:
            block.cam_encoder = nn.Linear(13, dim)
            block.projector = nn.Linear(dim, dim)
            block.cam_encoder.weight.data.zero_()
            block.cam_encoder.bias.data.zero_()
            block.projector.weight = nn.Parameter(torch.eye(dim))
            block.projector.bias = nn.Parameter(torch.zeros(dim))

        if resume_ckpt_path is not None:
            state_dict = torch.load(resume_ckpt_path, map_location="cpu")

            # Filter out cam_encoder weights so they keep their fresh init.
            filtered_state_dict = {}
            for key, value in state_dict.items():
                if 'cam_encoder' not in key:
                    filtered_state_dict[key] = value
                else:
                    print(f"🔧 跳过加载cam_encoder权重: {key}")

            # strict=False tolerates the intentionally-missing cam_encoder keys.
            missing_keys, unexpected_keys = self.pipe.dit.load_state_dict(filtered_state_dict, strict=False)
            print(f'✅ 加载checkpoint: {resume_ckpt_path}')
            print(f'🔧 cam_encoder保持随机初始化,未加载预训练权重')
            if missing_keys:
                print(f'⚠️ Missing keys (预期的,包含cam_encoder): {len(missing_keys)} keys')
            if unexpected_keys:
                print(f'⚠️ Unexpected keys: {unexpected_keys}')

        self.freeze_parameters()

        # Unfreeze only the camera, self-attention and FramePack modules.
        for name, module in self.pipe.denoising_model().named_modules():
            if any(keyword in name for keyword in ["cam_encoder", "projector", "self_attn", "clean_x_embedder"]):
                for param in module.parameters():
                    param.requires_grad = True

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

        # Directory for visualisation dumps.
        self.vis_dir = "spatialvid_framepack/visualizations"
        os.makedirs(self.vis_dir, exist_ok=True)

    def add_framepack_components(self):
        """Attach the FramePack clean_x_embedder to the DiT (idempotent)."""
        if not hasattr(self.pipe.dit, 'clean_x_embedder'):
            inner_dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]

            class CleanXEmbedder(nn.Module):
                """Multi-scale patchify projections for clean context latents,
                modelled after hunyuan_video_packed.py."""
                def __init__(self, inner_dim):
                    super().__init__()
                    # 1x / 2x / 4x spatio-temporal downsampling projections.
                    self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
                    self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
                    self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

                def forward(self, x, scale="1x"):
                    if scale == "1x":
                        return self.proj(x)
                    elif scale == "2x":
                        return self.proj_2x(x)
                    elif scale == "4x":
                        return self.proj_4x(x)
                    else:
                        raise ValueError(f"Unsupported scale: {scale}")

            self.pipe.dit.clean_x_embedder = CleanXEmbedder(inner_dim)
            print("✅ 添加了FramePack的clean_x_embedder组件")

    def freeze_parameters(self):
        """Freeze the whole pipeline; keep the denoising model in train mode."""
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()
    def training_step(self, batch, batch_idx):
        """FramePack-style training step for the SpatialVid variant.

        Consumes the multi-scale (1x/2x/4x) clean-context latents produced by
        the dataset and computes a diffusion MSE loss over the predicted
        (future) frames.

        Returns:
            Scalar loss weighted by the scheduler's timestep weight.
        """
        condition_frames = batch["condition_frames"][0].item()
        target_frames = batch["target_frames"][0].item()

        original_condition_frames = batch.get("original_condition_frames", [condition_frames * 4])[0]
        original_target_frames = batch.get("original_target_frames", [target_frames * 4])[0]

        dataset_name = batch.get("dataset_name", ["unknown"])[0]
        scene_name = batch.get("scene_name", ["unknown"])[0]

        # FramePack inputs — make sure the latents carry a batch dimension.
        latents = batch["latents"].to(self.device)
        if len(latents.shape) == 4:  # [C, T, H, W]
            latents = latents.unsqueeze(0)  # -> [1, C, T, H, W]

        # Condition inputs (empty tensors become None; dims normalised the same way).
        clean_latents = batch["clean_latents"].to(self.device) if batch["clean_latents"].numel() > 0 else None
        if clean_latents is not None and len(clean_latents.shape) == 4:
            clean_latents = clean_latents.unsqueeze(0)

        clean_latents_2x = batch["clean_latents_2x"].to(self.device) if batch["clean_latents_2x"].numel() > 0 else None
        if clean_latents_2x is not None and len(clean_latents_2x.shape) == 4:
            clean_latents_2x = clean_latents_2x.unsqueeze(0)

        clean_latents_4x = batch["clean_latents_4x"].to(self.device) if batch["clean_latents_4x"].numel() > 0 else None
        if clean_latents_4x is not None and len(clean_latents_4x.shape) == 4:
            clean_latents_4x = clean_latents_4x.unsqueeze(0)

        # Index tensors (empty tensors become None as well).
        latent_indices = batch["latent_indices"].to(self.device)
        clean_latent_indices = batch["clean_latent_indices"].to(self.device) if batch["clean_latent_indices"].numel() > 0 else None
        clean_latent_2x_indices = batch["clean_latent_2x_indices"].to(self.device) if batch["clean_latent_2x_indices"].numel() > 0 else None
        clean_latent_4x_indices = batch["clean_latent_4x_indices"].to(self.device) if batch["clean_latent_4x_indices"].numel() > 0 else None

        # Camera embeddings already carry the condition/target mask channel.
        cam_emb = batch["camera"].to(self.device)
        camera_dropout_prob = 0.1  # drop the camera condition 10% of the time (CFG)
        if random.random() < camera_dropout_prob:
            # Zero camera embedding stands in for the unconditional branch.
            cam_emb = torch.zeros_like(cam_emb)
            print("应用camera dropout for CFG training")

        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]

        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        # Loss computation.
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)

        # FramePack-style condition noising: 80% of steps perturb the clean
        # context with noise from a (capped) random timestep.
        noisy_condition_latents = None
        if clean_latents is not None:
            noisy_condition_latents = copy.deepcopy(clean_latents)
            is_add_noise = random.random()
            if is_add_noise > 0.2:  # 80% probability
                noise_cond = torch.randn_like(clean_latents)
                timestep_id_cond = torch.randint(0, self.pipe.scheduler.num_train_timesteps//4*3, (1,))
                timestep_cond = self.pipe.scheduler.timesteps[timestep_id_cond].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
                noisy_condition_latents = self.pipe.scheduler.add_noise(clean_latents, noise_cond, timestep_cond)

        extra_input = self.pipe.prepare_extra_input(latents)
        origin_latents = copy.deepcopy(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)

        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # FramePack-style forward pass with the multi-scale clean context.
        noise_pred = self.pipe.denoising_model()(
            noisy_latents,
            timestep=timestep,
            cam_emb=cam_emb,  # mask-augmented camera embeddings
            # FramePack-style conditional inputs.
            latent_indices=latent_indices,
            clean_latents=noisy_condition_latents if noisy_condition_latents is not None else clean_latents,
            clean_latent_indices=clean_latent_indices,
            clean_latents_2x=clean_latents_2x,
            clean_latent_2x_indices=clean_latent_2x_indices,
            clean_latents_4x=clean_latents_4x,
            clean_latent_4x_indices=clean_latent_4x_indices,
            **prompt_emb,
            **extra_input,
            **image_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )

        # Here `latents` only contains target frames, so the loss covers all of
        # the prediction.
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)
        print(f'--------loss ({dataset_name})------------:', loss)

        return loss
def train_spatialvid_framepack(args):
    """Train the FramePack-enabled SpatialVid model.

    Builds the dataset, dataloader, Lightning module and trainer from the
    parsed CLI `args`, then runs the fit loop.
    """
    train_dataset = SpatialVidFramePackDataset(
        args.dataset_path,
        steps_per_epoch=args.steps_per_epoch,
        min_condition_frames=args.min_condition_frames,
        max_condition_frames=args.max_condition_frames,
        target_frames=args.target_frames,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers,
    )

    lit_model = SpatialVidFramePackLightningModel(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        resume_ckpt_path=args.resume_ckpt_path,
    )

    pl_trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[],
        logger=False,
    )
    pl_trainer.fit(lit_model, train_loader)
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, causal=False):
    """Multi-head attention over packed [batch, seq, heads*dim] tensors.

    Dispatches to FlashAttention-3, FlashAttention-2 or SageAttention when
    available, otherwise falls back to torch's scaled_dot_product_attention.

    Args:
        q, k, v: packed projections of shape [b, s, n*d].
        num_heads: number of attention heads n.
        compatibility_mode: force the pure-PyTorch SDPA path.
        causal: apply a causal mask.

    Returns:
        Attention output of shape [b, s, n*d].
    """
    def _to_bnsd(t):
        # [b, s, n*d] -> [b, n, s, d]
        b, s, _ = t.shape
        return t.reshape(b, s, num_heads, -1).transpose(1, 2)

    def _from_bnsd(t):
        # [b, n, s, d] -> [b, s, n*d]
        b, n, s, d = t.shape
        return t.transpose(1, 2).reshape(b, s, n * d)

    def _to_bsnd(t):
        # [b, s, n*d] -> [b, s, n, d] (flash-attn layout)
        b, s, _ = t.shape
        return t.reshape(b, s, num_heads, -1)

    def _from_bsnd(t):
        # [b, s, n, d] -> [b, s, n*d]
        b, s, n, d = t.shape
        return t.reshape(b, s, n * d)

    if compatibility_mode:
        # BUGFIX: `causal` was accepted but silently ignored; forward it.
        x = F.scaled_dot_product_attention(_to_bnsd(q), _to_bnsd(k), _to_bnsd(v), is_causal=causal)
        return _from_bnsd(x)
    if FLASH_ATTN_3_AVAILABLE:
        # BUGFIX: forward `causal` to FlashAttention-3.
        x = flash_attn_interface.flash_attn_func(_to_bsnd(q), _to_bsnd(k), _to_bsnd(v), causal=causal)
        return _from_bsnd(x)
    if FLASH_ATTN_2_AVAILABLE:
        # BUGFIX: forward `causal` to FlashAttention-2.
        x = flash_attn.flash_attn_func(_to_bsnd(q), _to_bsnd(k), _to_bsnd(v), causal=causal)
        return _from_bsnd(x)
    if SAGE_ATTN_AVAILABLE:
        # NOTE(review): sageattn's causal-mask support varies across versions;
        # `causal` is not forwarded here — confirm before relying on this path.
        x = sageattn(_to_bnsd(q), _to_bnsd(k), _to_bnsd(v))
        return _from_bnsd(x)
    # Pure-PyTorch fallback.
    x = F.scaled_dot_product_attention(_to_bnsd(q), _to_bnsd(k), _to_bnsd(v), is_causal=causal)
    return _from_bnsd(x)
torch.view_as_real(x_out * freqs).flatten(2) + return x_out.to(x.dtype) + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def forward(self, x): + dtype = x.dtype + return self.norm(x.float()).to(dtype) * self.weight + + +class AttentionModule(nn.Module): + def __init__(self, num_heads, causal=False): + super().__init__() + self.num_heads = num_heads + + def forward(self, q, k, v): + x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads) + return x + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, causal: bool = False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x, freqs): + x = x.to(self.q.weight.dtype) + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + x = self.attn(q, k, v) + return self.o(x) + + +class CrossAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + self.has_image_input = has_image_input + if has_image_input: + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = 
RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + if self.has_image_input: + img = y[:, :257] + ctx = y[:, 257:] + else: + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + x = self.attn(q, k, v) + if self.has_image_input: + k_img = self.norm_k_img(self.k_img(img)) + v_img = self.v_img(img) + y = flash_attention(q, k_img, v_img, num_heads=self.num_heads) + x = x + y + return self.o(x) + + +class DiTBlock(nn.Module): + def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.ffn_dim = ffn_dim + + #self.self_attn = SelfAttention(dim, num_heads, eps, causal=True) # Enable causal masking + self.self_attn = SelfAttention(dim, num_heads, eps) + self.cross_attn = CrossAttention( + dim, num_heads, eps, has_image_input=has_image_input) + self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU( + approximate='tanh'), nn.Linear(ffn_dim, dim)) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward(self, x, context, cam_emb, t_mod, freqs): + # msa: multi-head self-attention mlp: multi-layer perceptron + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + + # encode camera + cam_emb = self.cam_encoder(cam_emb) + cam_emb = torch.cat([torch.zeros_like(cam_emb).repeat(1, 2, 1), cam_emb], dim=1) + #cam_emb = cam_emb.repeat(1,2,1) + cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, 30, 52, 1) + cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d') + input_x = input_x + cam_emb + 
x = x + gate_msa * self.projector(self.self_attn(input_x, freqs)) + + x = x + self.cross_attn(self.norm3(x), context) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = x + gate_mlp * self.ffn(input_x) + return x + + +class MLP(torch.nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.proj = torch.nn.Sequential( + nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + nn.LayerNorm(out_dim) + ) + + def forward(self, x): + return self.proj(x) + + +class Head(nn.Module): + def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): + super().__init__() + self.dim = dim + self.patch_size = patch_size + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, t_mod): + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) + return x + + +class WanModel(torch.nn.Module): + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + has_image_input: bool, + ): + super().__init__() + self.dim = dim + self.freq_dim = freq_dim + self.has_image_input = has_image_input + self.patch_size = patch_size + + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + self.blocks = nn.ModuleList([ + DiTBlock(has_image_input, dim, 
num_heads, ffn_dim, eps) + for _ in range(num_layers) + ]) + self.head = Head(dim, out_dim, patch_size, eps) + head_dim = dim // num_heads + self.freqs = precompute_freqs_cis_3d(head_dim) + + if has_image_input: + self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280 + + def patchify(self, x: torch.Tensor): + x = self.patch_embedding(x) + grid_size = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + return x, grid_size # x, grid_size: (f, h, w) + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], h=grid_size[1], w=grid_size[2], + x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2] + ) + + def forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + cam_emb: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block in self.blocks: + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = 
torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, cam_emb, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, cam_emb, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, cam_emb, t_mod, freqs) + + x = self.head(x, t) + x = self.unpatchify(x, (f, h, w)) + return x + + @staticmethod + def state_dict_converter(): + return WanModelStateDictConverter() + + +class WanModelStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": 
"blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + "blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name in rename_dict: + state_dict_[rename_dict[name]] = param + else: + name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = param + if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b": + config = { + "model_type": "t2v", + "patch_size": (1, 2, 2), + "text_len": 512, + "in_dim": 16, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "window_size": (-1, -1), + 
"qk_norm": True, + "cross_attn_norm": True, + "eps": 1e-6, + } + else: + config = {} + return state_dict_, config + + def from_civitai(self, state_dict): + if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814": + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 16, + "dim": 1536, + "ffn_dim": 8960, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 12, + "num_layers": 30, + "eps": 1e-6 + } + elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70": + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 16, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6 + } + elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e": + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6 + } + else: + config = {} + return state_dict, config