| | import os |
| | import torch |
| | import numpy as np |
| | from PIL import Image |
| | import imageio |
| | import argparse |
| | from diffsynth import WanVideoReCamMasterPipeline, ModelManager |
| | from tqdm import tqdm |
| | import json |
| |
|
class VideoDecoder:
    """Decode VAE latents into RGB video frames and write them to disk.

    The VAE checkpoint is loaded through ``ModelManager`` and wrapped in a
    ``WanVideoReCamMasterPipeline``; this class only uses the pipeline's
    ``vae`` attribute and its ``decode_video`` method.
    """

    def __init__(self, vae_path, device="cuda"):
        """Load the VAE checkpoint at *vae_path* and move it to *device*.

        Args:
            vae_path: Path of the checkpoint handed to ``ModelManager.load_models``.
            device: Device the pipeline and its VAE are moved to.
        """
        self.device = device

        # Weights are loaded with the manager on CPU, then the assembled
        # pipeline is moved to the requested device.
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([vae_path])

        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe = self.pipe.to(device)

        # Move the VAE (and its inner ``model`` when present) explicitly too;
        # presumably ``pipe.to`` does not always cover them - TODO confirm.
        self.pipe.vae = self.pipe.vae.to(device)
        if hasattr(self.pipe.vae, 'model'):
            self.pipe.vae.model = self.pipe.vae.model.to(device)

        print(f"✅ VAE解码器初始化完成,设备: {device}")

    @staticmethod
    def _to_thwc(decoded_video):
        """Return *decoded_video* as a float32 numpy array laid out [T, H, W, C].

        Accepted layouts are 5D ``[B, 3, T, H, W]`` / ``[B, T, H, W, 3]``
        (only the first batch element is used) and 4D ``[3, T, H, W]`` /
        ``[T, H, W, 3]``.

        Raises:
            ValueError: If the tensor is not 4D/5D or no supported dimension
                has size 3.
        """
        shape = tuple(decoded_video.shape)

        if len(shape) == 5:
            if shape[0] > 1:
                # Only the first sample is converted; say so instead of
                # dropping the remaining samples silently.
                print(f"⚠️ batch大小为 {shape[0]},仅使用第一个样本")
            if shape[1] == 3:
                print("🔧 检测到格式: [B, C, T, H, W]")
                frames = decoded_video[0].permute(1, 2, 3, 0)
            elif shape[-1] == 3:
                print("🔧 检测到格式: [B, T, H, W, C]")
                frames = decoded_video[0]
            elif 3 in shape:
                channel_dim = shape.index(3)
                print(f"⚠️ 未知的通道维度位置: {channel_dim}")
                raise ValueError(f"Cannot handle channel dimension at position {channel_dim}")
            else:
                print(f"⚠️ 未找到通道维度为3的位置,形状: {decoded_video.shape}")
                raise ValueError(f"Cannot find channel dimension of size 3 in shape {decoded_video.shape}")
        elif len(shape) == 4:
            if shape[-1] == 3:
                frames = decoded_video
            elif shape[0] == 3:
                frames = decoded_video.permute(1, 2, 3, 0)
            else:
                print(f"⚠️ 无法处理的4D视频形状: {decoded_video.shape}")
                raise ValueError(f"Cannot handle 4D video tensor shape: {decoded_video.shape}")
        else:
            print(f"⚠️ 意外的视频维度数: {len(shape)}")
            raise ValueError(f"Unexpected video tensor dimensions: {decoded_video.shape}")

        # Cast to float32 before leaving torch (the model is loaded in bfloat16).
        return frames.to(torch.float32).cpu().numpy()

    @staticmethod
    def _to_uint8(video_np):
        """Map values from [-1, 1] to uint8 [0, 255], clipping anything outside."""
        return ((video_np * 0.5 + 0.5).clip(0, 1) * 255).astype(np.uint8)

    def _decode(self, latents, model_device, tiled, tile_size, tile_stride):
        """Run the pipeline decoder, falling back to a direct VAE call on failure."""
        with torch.no_grad():
            try:
                if tiled:
                    print("🔧 尝试tiled解码...")
                    return self.pipe.decode_video(
                        latents,
                        tiled=True,
                        tile_size=tile_size,
                        tile_stride=tile_stride,
                    )
                print("🔧 使用非tiled解码...")
                return self.pipe.decode_video(latents, tiled=False)
            except Exception as e:
                print(f"decode_video失败,错误: {e}")
                import traceback
                traceback.print_exc()

                try:
                    print("🔧 尝试直接调用VAE解码...")
                    decoded_video = self.pipe.vae.decode(
                        latents.squeeze(0),
                        device=model_device,
                        tiled=False,
                    )
                    if len(decoded_video.shape) == 4:
                        decoded_video = decoded_video.unsqueeze(0)
                    return decoded_video
                except Exception as e2:
                    print(f"直接VAE解码也失败: {e2}")
                    # Bare raise keeps the original traceback intact.
                    raise

    @staticmethod
    def _write_video(video_np, output_path, fps, quality):
        """Write uint8 frames to *output_path*; dump a few PNGs if writing fails."""
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        try:
            with imageio.get_writer(output_path, fps=fps, quality=quality) as writer:
                for frame_idx, frame in enumerate(video_np):
                    writer.append_data(frame)
                    if frame_idx % 10 == 0:
                        print(f" 写入帧 {frame_idx}/{len(video_np)}")
        except Exception as e:
            print(f"保存视频失败: {e}")

            # Best-effort debugging aid: save the first frames as images so
            # the decoded content can still be inspected.
            debug_dir = os.path.join(output_dir, "debug_frames")
            os.makedirs(debug_dir, exist_ok=True)
            for i, frame in enumerate(video_np[:5]):
                debug_path = os.path.join(debug_dir, f"debug_frame_{i}.png")
                try:
                    Image.fromarray(frame).save(debug_path)
                    print(f"调试: 保存帧 {i} 到 {debug_path}")
                except Exception as e2:
                    print(f"调试: 保存帧 {i} 失败: {e2}")
            raise

    def decode_latents_to_video(self, latents, output_path, tiled=True, tile_size=(34, 34),
                                tile_stride=(18, 16), fps=10, quality=8):
        """Decode *latents* with the VAE and save the result as a video file.

        Args:
            latents: Latent tensor; a 4D tensor gets a leading batch dimension.
            output_path: Destination video file. Its parent directory is
                created when the path has one.
            tiled: Whether to request tiled decoding from the pipeline.
            tile_size: Tile size forwarded to ``decode_video`` when tiled.
            tile_stride: Tile stride forwarded to ``decode_video`` when tiled.
            fps: Frame rate passed to the video writer.
            quality: Quality setting passed to the video writer.

        Returns:
            The decoded video as a uint8 numpy array of shape [T, H, W, 3].

        Raises:
            ValueError: If the decoded tensor has an unsupported layout.
            Exception: Decoder or writer errors are re-raised after logging.
        """
        print(f"🔧 开始解码latents...")
        print(f"输入latents形状: {latents.shape}")
        print(f"输入latents设备: {latents.device}")
        print(f"输入latents数据类型: {latents.dtype}")

        # A 4D input is treated as one un-batched sample.
        if len(latents.shape) == 4:
            latents = latents.unsqueeze(0)

        # Match the latents to wherever the VAE weights actually live.
        vae_param = next(self.pipe.vae.parameters())
        model_dtype = vae_param.dtype
        model_device = vae_param.device

        print(f"模型设备: {model_device}")
        print(f"模型数据类型: {model_dtype}")

        latents = latents.to(device=model_device, dtype=model_dtype)

        print(f"解码latents形状: {latents.shape}")
        print(f"解码latents设备: {latents.device}")
        print(f"解码latents数据类型: {latents.dtype}")

        # decode_video is called without a device argument below; presumably
        # it reads the pipeline's ``device`` attribute - TODO confirm.
        self.pipe.device = model_device

        decoded_video = self._decode(latents, model_device, tiled, tile_size, tile_stride)
        print(f"解码后视频形状: {decoded_video.shape}")

        video_np = self._to_thwc(decoded_video)
        print(f"转换后视频数组形状: {video_np.shape}")

        video_np = self._to_uint8(video_np)
        print(f"最终视频数组形状: {video_np.shape}")
        if video_np.size:
            # min()/max() raise on a zero-size array, so skip them when empty.
            print(f"视频数组值范围: {video_np.min()} - {video_np.max()}")

        self._write_video(video_np, output_path, fps, quality)

        print(f"✅ 视频保存到: {output_path}")
        return video_np

    def save_frames_as_images(self, video_np, output_dir, prefix="frame"):
        """Save each [H, W, 3] frame of *video_np* as ``<prefix>_<index>.png``.

        Frames that are not 3D with three channels are skipped with a
        warning; the final message reports how many were actually written.

        Args:
            video_np: Iterable of frame arrays.
            output_dir: Directory for the images; created if missing.
            prefix: File-name prefix placed before the zero-padded index.
        """
        os.makedirs(output_dir, exist_ok=True)

        saved = 0
        for i, frame in enumerate(video_np):
            if len(frame.shape) == 3 and frame.shape[-1] == 3:
                frame_path = os.path.join(output_dir, f"{prefix}_{i:04d}.png")
                Image.fromarray(frame).save(frame_path)
                saved += 1
            else:
                print(f"⚠️ 跳过形状异常的帧 {i}: {frame.shape}")

        print(f"✅ 保存了 {saved} 帧到: {output_dir}")
| |
|
def decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device="cuda"):
    """Decode one encoded episode file and write the results to disk.

    Loads the file at *encoded_pth_path*, reads its ``'latents'`` tensor and
    optional ``'episode_info'`` dict, decodes the latents with a new
    ``VideoDecoder`` and writes ``decoded_video.mp4``, up to ten sample
    frames and ``decode_info.json`` into ``<output_base_dir>/episode_<idx>/``.

    Args:
        encoded_pth_path: Path of the encoded ``.pth`` file.
        vae_path: VAE checkpoint path handed to ``VideoDecoder``.
        output_base_dir: Directory under which the episode folder is created.
        device: Device handed to ``VideoDecoder``.

    Returns:
        True when decoding and all outputs succeeded, otherwise False.
    """
    print(f"\n🔧 解码episode: {encoded_pth_path}")

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this tool at encoded files from a trusted source.
    try:
        encoded_data = torch.load(encoded_pth_path, weights_only=False, map_location="cpu")
        print(f"✅ 成功加载编码数据")
    except Exception as e:
        print(f"❌ 加载编码数据失败: {e}")
        return False

    # Everything below uses dict access; reject other payloads up front
    # instead of crashing with AttributeError.
    if not isinstance(encoded_data, dict):
        print(f"❌ 编码数据格式错误: 期望dict, 实际为 {type(encoded_data)}")
        return False

    print("🔍 编码数据结构:")
    for key, value in encoded_data.items():
        if isinstance(value, torch.Tensor):
            print(f" - {key}: {value.shape}, dtype: {value.dtype}, device: {value.device}")
        elif isinstance(value, dict):
            print(f" - {key}: dict with keys {list(value.keys())}")
        else:
            print(f" - {key}: {type(value)}")

    latents = encoded_data.get('latents')
    if latents is None:
        print("❌ 未找到latents数据")
        return False
    # shape[1] is read below, so the tensor needs at least two dimensions.
    if not isinstance(latents, torch.Tensor) or latents.dim() < 2:
        print(f"❌ latents数据格式错误: {type(latents)}")
        return False

    if latents.device != torch.device('cpu'):
        latents = latents.cpu()
        print(f"🔧 将latents移动到CPU: {latents.device}")

    # Tolerate a missing or non-dict 'episode_info' entry.
    episode_info = encoded_data.get('episode_info')
    if not isinstance(episode_info, dict):
        episode_info = {}
    episode_idx = episode_info.get('episode_idx', 'unknown')
    # Fallback assumes dimension 1 is the latent time axis with 4x temporal
    # compression - TODO confirm against the encoder.
    total_frames = episode_info.get('total_frames', latents.shape[1] * 4)

    print(f"Episode信息:")
    print(f" - Episode索引: {episode_idx}")
    print(f" - Latents形状: {latents.shape}")
    print(f" - Latents设备: {latents.device}")
    print(f" - Latents数据类型: {latents.dtype}")
    print(f" - 原始总帧数: {total_frames}")
    print(f" - 压缩后帧数: {latents.shape[1]}")

    episode_name = f"episode_{episode_idx:06d}" if isinstance(episode_idx, int) else f"episode_{episode_idx}"
    output_dir = os.path.join(output_base_dir, episode_name)
    os.makedirs(output_dir, exist_ok=True)

    try:
        decoder = VideoDecoder(vae_path, device)
    except Exception as e:
        print(f"❌ 初始化解码器失败: {e}")
        return False

    video_output_path = os.path.join(output_dir, "decoded_video.mp4")
    try:
        video_np = decoder.decode_latents_to_video(
            latents,
            video_output_path,
            tiled=False,
            tile_size=(34, 34),
            tile_stride=(18, 16),
        )

        # Keep only the first ten frames as still images for quick inspection.
        frames_dir = os.path.join(output_dir, "frames")
        sample_frames = video_np[:10]
        decoder.save_frames_as_images(sample_frames, frames_dir, f"frame_{episode_idx}")

        decoded_frames = len(video_np)
        latent_frames = latents.shape[1]
        decode_info = {
            "source_pth": encoded_pth_path,
            "decoded_video_path": video_output_path,
            "latents_shape": list(latents.shape),
            "decoded_video_shape": list(video_np.shape),
            "original_total_frames": total_frames,
            "decoded_frames": decoded_frames,
            "compression_ratio": total_frames / decoded_frames if decoded_frames > 0 else 0,
            "latents_dtype": str(latents.dtype),
            "latents_device": str(latents.device),
            "vae_compression_ratio": total_frames / latent_frames if latent_frames > 0 else 0,
        }

        info_path = os.path.join(output_dir, "decode_info.json")
        with open(info_path, 'w', encoding="utf-8") as f:
            json.dump(decode_info, f, indent=2)

        print(f"✅ Episode {episode_idx} 解码完成")
        print(f" - 原始帧数: {total_frames}")
        print(f" - 解码帧数: {decoded_frames}")
        print(f" - 压缩比: {decode_info['compression_ratio']:.2f}")
        print(f" - VAE时间压缩比: {decode_info['vae_compression_ratio']:.2f}")
        return True

    except Exception as e:
        print(f"❌ 解码失败: {e}")
        import traceback
        traceback.print_exc()
        return False
| |
|
def batch_decode_episodes(encoded_base_dir, vae_path, output_base_dir, max_episodes=None, device="cuda"):
    """Decode every ``<episode>/encoded_video.pth`` found under *encoded_base_dir*.

    Sub-directories are visited in sorted name order and each encoded file is
    passed to ``decode_single_episode``. Progress and the final success rate
    are printed.

    Args:
        encoded_base_dir: Directory whose sub-directories hold the encoded files.
        vae_path: VAE checkpoint path forwarded to ``decode_single_episode``.
        output_base_dir: Output directory forwarded to ``decode_single_episode``.
        max_episodes: Optional cap on the number of episodes processed.
            ``None``, zero or a negative value means no cap.
        device: Device forwarded to ``decode_single_episode``.

    Returns:
        None.
    """
    print(f"🔧 批量解码Open-X episodes")
    print(f"源目录: {encoded_base_dir}")
    print(f"输出目录: {output_base_dir}")

    # isdir (rather than exists) so that a plain file does not reach listdir.
    episode_paths = []
    if os.path.isdir(encoded_base_dir):
        for item in sorted(os.listdir(encoded_base_dir)):
            episode_dir = os.path.join(encoded_base_dir, item)
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.isdir(episode_dir) and os.path.exists(encoded_path):
                episode_paths.append(encoded_path)

    print(f"找到 {len(episode_paths)} 个编码的episodes")

    # Nothing to do: return before the success-rate division by zero.
    if not episode_paths:
        print("⚠️ 没有可解码的episodes")
        return

    if max_episodes is not None and 0 < max_episodes < len(episode_paths):
        episode_paths = episode_paths[:max_episodes]
        print(f"限制处理前 {max_episodes} 个episodes")

    total = len(episode_paths)
    success_count = 0
    for done, encoded_pth_path in enumerate(tqdm(episode_paths, desc="解码episodes"), start=1):
        print(f"\n{'='*60}")
        print(f"处理 {done}/{total}: {os.path.basename(os.path.dirname(encoded_pth_path))}")

        if decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device):
            success_count += 1

        print(f"当前成功率: {success_count}/{done} ({success_count/done*100:.1f}%)")

    print(f"\n🎉 批量解码完成!")
    print(f"总处理: {total} 个episodes")
    print(f"成功解码: {success_count} 个episodes")
    print(f"成功率: {success_count/total*100:.1f}%")
| |
|
def main():
    """Command-line entry point for the latent decoding check.

    Parses the CLI options, falls back to CPU when CUDA was requested but is
    not available, creates the output directory and then runs either the
    single-episode or the batch decoding path.
    """
    cli = argparse.ArgumentParser(description="解码Open-X编码的latents以验证正确性 - 修正版本")
    cli.add_argument(
        "--mode",
        type=str,
        choices=["single", "batch"],
        default="batch",
        help="解码模式:single (单个episode) 或 batch (批量)",
    )
    cli.add_argument(
        "--encoded_pth",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000000/encoded_video.pth",
        help="单个编码文件路径(single模式)",
    )
    cli.add_argument(
        "--encoded_base_dir",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
        help="编码数据基础目录(batch模式)",
    )
    cli.add_argument(
        "--vae_path",
        type=str,
        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
        help="VAE模型路径",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default="./decoded_results_fixed",
        help="解码输出目录",
    )
    cli.add_argument(
        "--max_episodes",
        type=int,
        default=5,
        help="最大解码episodes数量(batch模式,用于测试)",
    )
    cli.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="计算设备",
    )

    args = cli.parse_args()

    # Startup banner echoing the effective settings.
    banner = (
        "🔧 Open-X Latents 解码验证工具 (修正版本 - Fixed)",
        f"模式: {args.mode}",
        f"VAE路径: {args.vae_path}",
        f"输出目录: {args.output_dir}",
        f"设备: {args.device}",
    )
    for line in banner:
        print(line)

    # Only probe CUDA when it was actually requested.
    if args.device == "cuda":
        if not torch.cuda.is_available():
            print("⚠️ CUDA不可用,切换到CPU")
            args.device = "cpu"

    os.makedirs(args.output_dir, exist_ok=True)

    if args.mode == "single":
        print(f"输入文件: {args.encoded_pth}")
        if not os.path.exists(args.encoded_pth):
            print(f"❌ 输入文件不存在: {args.encoded_pth}")
            return

        succeeded = decode_single_episode(args.encoded_pth, args.vae_path, args.output_dir, args.device)
        outcome = "✅ 单个episode解码成功" if succeeded else "❌ 单个episode解码失败"
        print(outcome)
        return

    if args.mode != "batch":
        return

    print(f"输入目录: {args.encoded_base_dir}")
    print(f"最大episodes: {args.max_episodes}")

    if not os.path.exists(args.encoded_base_dir):
        print(f"❌ 输入目录不存在: {args.encoded_base_dir}")
        return

    batch_decode_episodes(args.encoded_base_dir, args.vae_path, args.output_dir, args.max_episodes, args.device)
| |
|
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()