Astra / scripts /decode_openx.py
EvanEternal's picture
Upload 86 files
08bf07d verified
raw
history blame
19.4 kB
import os
import torch
import numpy as np
from PIL import Image
import imageio
import argparse
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
from tqdm import tqdm
import json
class VideoDecoder:
def __init__(self, vae_path, device="cuda"):
"""初始化视频解码器"""
self.device = device
# 初始化模型管理器
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([vae_path])
# 创建pipeline并只保留VAE
self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
self.pipe = self.pipe.to(device)
# 🔧 关键修复:确保VAE及其所有组件都在正确设备上
self.pipe.vae = self.pipe.vae.to(device)
if hasattr(self.pipe.vae, 'model'):
self.pipe.vae.model = self.pipe.vae.model.to(device)
print(f"✅ VAE解码器初始化完成,设备: {device}")
def decode_latents_to_video(self, latents, output_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
"""
将latents解码为视频 - 修正版本,修复维度处理问题
"""
print(f"🔧 开始解码latents...")
print(f"输入latents形状: {latents.shape}")
print(f"输入latents设备: {latents.device}")
print(f"输入latents数据类型: {latents.dtype}")
# 确保latents有batch维度
if len(latents.shape) == 4: # [C, T, H, W]
latents = latents.unsqueeze(0) # -> [1, C, T, H, W]
# 🔧 关键修正:确保latents在正确的设备上且数据类型匹配
model_dtype = next(self.pipe.vae.parameters()).dtype
model_device = next(self.pipe.vae.parameters()).device
print(f"模型设备: {model_device}")
print(f"模型数据类型: {model_dtype}")
# 将latents移动到正确的设备和数据类型
latents = latents.to(device=model_device, dtype=model_dtype)
print(f"解码latents形状: {latents.shape}")
print(f"解码latents设备: {latents.device}")
print(f"解码latents数据类型: {latents.dtype}")
# 🔧 强制设置pipeline设备,确保所有操作在同一设备上
self.pipe.device = model_device
# 使用VAE解码
with torch.no_grad():
try:
if tiled:
print("🔧 尝试tiled解码...")
decoded_video = self.pipe.decode_video(
latents,
tiled=True,
tile_size=tile_size,
tile_stride=tile_stride
)
else:
print("🔧 使用非tiled解码...")
decoded_video = self.pipe.decode_video(latents, tiled=False)
except Exception as e:
print(f"decode_video失败,错误: {e}")
import traceback
traceback.print_exc()
# 🔧 fallback: 尝试直接调用VAE
try:
print("🔧 尝试直接调用VAE解码...")
decoded_video = self.pipe.vae.decode(
latents.squeeze(0), # 移除batch维度 [C, T, H, W]
device=model_device,
tiled=False
)
# 手动调整维度: VAE输出 [T, H, W, C] -> [1, T, H, W, C]
if len(decoded_video.shape) == 4: # [T, H, W, C]
decoded_video = decoded_video.unsqueeze(0) # -> [1, T, H, W, C]
except Exception as e2:
print(f"直接VAE解码也失败: {e2}")
raise e2
print(f"解码后视频形状: {decoded_video.shape}")
# 🔧 关键修正:正确处理维度顺序
video_np = None
if len(decoded_video.shape) == 5:
# 检查不同的可能维度顺序
if decoded_video.shape == torch.Size([1, 3, 113, 480, 832]):
# 格式: [B, C, T, H, W] -> 需要转换为 [T, H, W, C]
print("🔧 检测到格式: [B, C, T, H, W]")
video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy() # [T, H, W, C]
elif decoded_video.shape[1] == 3:
# 如果第二个维度是3,可能是 [B, C, T, H, W]
print("🔧 检测到可能的格式: [B, C, T, H, W]")
video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy() # [T, H, W, C]
elif decoded_video.shape[-1] == 3:
# 如果最后一个维度是3,可能是 [B, T, H, W, C]
print("🔧 检测到格式: [B, T, H, W, C]")
video_np = decoded_video[0].to(torch.float32).cpu().numpy() # [T, H, W, C]
else:
# 尝试找到维度为3的位置
shape = list(decoded_video.shape)
if 3 in shape:
channel_dim = shape.index(3)
print(f"🔧 检测到通道维度在位置: {channel_dim}")
if channel_dim == 1: # [B, C, T, H, W]
video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()
elif channel_dim == 4: # [B, T, H, W, C]
video_np = decoded_video[0].to(torch.float32).cpu().numpy()
else:
print(f"⚠️ 未知的通道维度位置: {channel_dim}")
raise ValueError(f"Cannot handle channel dimension at position {channel_dim}")
else:
print(f"⚠️ 未找到通道维度为3的位置,形状: {decoded_video.shape}")
raise ValueError(f"Cannot find channel dimension of size 3 in shape {decoded_video.shape}")
elif len(decoded_video.shape) == 4:
# 4维张量,检查可能的格式
if decoded_video.shape[-1] == 3: # [T, H, W, C]
video_np = decoded_video.to(torch.float32).cpu().numpy()
elif decoded_video.shape[0] == 3: # [C, T, H, W]
video_np = decoded_video.permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()
else:
print(f"⚠️ 无法处理的4D视频形状: {decoded_video.shape}")
raise ValueError(f"Cannot handle 4D video tensor shape: {decoded_video.shape}")
else:
print(f"⚠️ 意外的视频维度数: {len(decoded_video.shape)}")
raise ValueError(f"Unexpected video tensor dimensions: {decoded_video.shape}")
if video_np is None:
raise ValueError("Failed to convert video tensor to numpy array")
print(f"转换后视频数组形状: {video_np.shape}")
# 🔧 验证最终形状
if len(video_np.shape) != 4:
raise ValueError(f"Expected 4D array [T, H, W, C], got {video_np.shape}")
if video_np.shape[-1] != 3:
print(f"⚠️ 通道数异常: 期望3,实际{video_np.shape[-1]}")
print(f"完整形状: {video_np.shape}")
# 尝试其他维度排列
if video_np.shape[0] == 3: # [C, T, H, W]
print("🔧 尝试重新排列: [C, T, H, W] -> [T, H, W, C]")
video_np = np.transpose(video_np, (1, 2, 3, 0))
elif video_np.shape[1] == 3: # [T, C, H, W]
print("🔧 尝试重新排列: [T, C, H, W] -> [T, H, W, C]")
video_np = np.transpose(video_np, (0, 2, 3, 1))
else:
raise ValueError(f"Expected 3 channels (RGB), got {video_np.shape[-1]} channels")
# 反归一化
video_np = (video_np * 0.5 + 0.5).clip(0, 1) # 反归一化
video_np = (video_np * 255).astype(np.uint8)
print(f"最终视频数组形状: {video_np.shape}")
print(f"视频数组值范围: {video_np.min()} - {video_np.max()}")
# 保存视频
os.makedirs(os.path.dirname(output_path), exist_ok=True)
try:
with imageio.get_writer(output_path, fps=10, quality=8) as writer:
for frame_idx, frame in enumerate(video_np):
# 🔧 验证每一帧的形状
if len(frame.shape) != 3 or frame.shape[-1] != 3:
print(f"⚠️ 帧 {frame_idx} 形状异常: {frame.shape}")
continue
writer.append_data(frame)
if frame_idx % 10 == 0:
print(f" 写入帧 {frame_idx}/{len(video_np)}")
except Exception as e:
print(f"保存视频失败: {e}")
# 🔧 尝试保存前几帧为图片进行调试
debug_dir = os.path.join(os.path.dirname(output_path), "debug_frames")
os.makedirs(debug_dir, exist_ok=True)
for i in range(min(5, len(video_np))):
frame = video_np[i]
debug_path = os.path.join(debug_dir, f"debug_frame_{i}.png")
try:
if len(frame.shape) == 3 and frame.shape[-1] == 3:
Image.fromarray(frame).save(debug_path)
print(f"调试: 保存帧 {i}{debug_path}")
else:
print(f"调试: 帧 {i} 形状异常: {frame.shape}")
except Exception as e2:
print(f"调试: 保存帧 {i} 失败: {e2}")
raise e
print(f"✅ 视频保存到: {output_path}")
return video_np
def save_frames_as_images(self, video_np, output_dir, prefix="frame"):
"""将视频帧保存为单独的图像文件"""
os.makedirs(output_dir, exist_ok=True)
for i, frame in enumerate(video_np):
frame_path = os.path.join(output_dir, f"{prefix}_{i:04d}.png")
# 🔧 验证帧形状
if len(frame.shape) == 3 and frame.shape[-1] == 3:
Image.fromarray(frame).save(frame_path)
else:
print(f"⚠️ 跳过形状异常的帧 {i}: {frame.shape}")
print(f"✅ 保存了 {len(video_np)} 帧到: {output_dir}")
def decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device="cuda"):
"""解码单个episode的编码数据 - 修正版本"""
print(f"\n🔧 解码episode: {encoded_pth_path}")
# 加载编码数据
try:
encoded_data = torch.load(encoded_pth_path, weights_only=False, map_location="cpu")
print(f"✅ 成功加载编码数据")
except Exception as e:
print(f"❌ 加载编码数据失败: {e}")
return False
# 检查数据结构
print("🔍 编码数据结构:")
for key, value in encoded_data.items():
if isinstance(value, torch.Tensor):
print(f" - {key}: {value.shape}, dtype: {value.dtype}, device: {value.device}")
elif isinstance(value, dict):
print(f" - {key}: dict with keys {list(value.keys())}")
else:
print(f" - {key}: {type(value)}")
# 获取latents
latents = encoded_data.get('latents')
if latents is None:
print("❌ 未找到latents数据")
return False
# 🔧 确保latents在CPU上(加载时的默认状态)
if latents.device != torch.device('cpu'):
latents = latents.cpu()
print(f"🔧 将latents移动到CPU: {latents.device}")
episode_info = encoded_data.get('episode_info', {})
episode_idx = episode_info.get('episode_idx', 'unknown')
total_frames = episode_info.get('total_frames', latents.shape[1] * 4) # 估算原始帧数
print(f"Episode信息:")
print(f" - Episode索引: {episode_idx}")
print(f" - Latents形状: {latents.shape}")
print(f" - Latents设备: {latents.device}")
print(f" - Latents数据类型: {latents.dtype}")
print(f" - 原始总帧数: {total_frames}")
print(f" - 压缩后帧数: {latents.shape[1]}")
# 创建输出目录
episode_name = f"episode_{episode_idx:06d}" if isinstance(episode_idx, int) else f"episode_{episode_idx}"
output_dir = os.path.join(output_base_dir, episode_name)
os.makedirs(output_dir, exist_ok=True)
# 初始化解码器
try:
decoder = VideoDecoder(vae_path, device)
except Exception as e:
print(f"❌ 初始化解码器失败: {e}")
return False
# 解码为视频
video_output_path = os.path.join(output_dir, "decoded_video.mp4")
try:
video_np = decoder.decode_latents_to_video(
latents,
video_output_path,
tiled=False, # 🔧 首先尝试非tiled解码,避免tiled的复杂性
tile_size=(34, 34),
tile_stride=(18, 16)
)
# 保存前几帧为图像(用于快速检查)
frames_dir = os.path.join(output_dir, "frames")
sample_frames = video_np[:min(10, len(video_np))] # 只保存前10帧
decoder.save_frames_as_images(sample_frames, frames_dir, f"frame_{episode_idx}")
# 保存解码信息
decode_info = {
"source_pth": encoded_pth_path,
"decoded_video_path": video_output_path,
"latents_shape": list(latents.shape),
"decoded_video_shape": list(video_np.shape),
"original_total_frames": total_frames,
"decoded_frames": len(video_np),
"compression_ratio": total_frames / len(video_np) if len(video_np) > 0 else 0,
"latents_dtype": str(latents.dtype),
"latents_device": str(latents.device),
"vae_compression_ratio": total_frames / latents.shape[1] if latents.shape[1] > 0 else 0
}
info_path = os.path.join(output_dir, "decode_info.json")
with open(info_path, 'w') as f:
json.dump(decode_info, f, indent=2)
print(f"✅ Episode {episode_idx} 解码完成")
print(f" - 原始帧数: {total_frames}")
print(f" - 解码帧数: {len(video_np)}")
print(f" - 压缩比: {decode_info['compression_ratio']:.2f}")
print(f" - VAE时间压缩比: {decode_info['vae_compression_ratio']:.2f}")
return True
except Exception as e:
print(f"❌ 解码失败: {e}")
import traceback
traceback.print_exc()
return False
def batch_decode_episodes(encoded_base_dir, vae_path, output_base_dir, max_episodes=None, device="cuda"):
"""批量解码episodes"""
print(f"🔧 批量解码Open-X episodes")
print(f"源目录: {encoded_base_dir}")
print(f"输出目录: {output_base_dir}")
# 查找所有编码的episodes
episode_dirs = []
if os.path.exists(encoded_base_dir):
for item in sorted(os.listdir(encoded_base_dir)): # 排序确保一致性
episode_dir = os.path.join(encoded_base_dir, item)
if os.path.isdir(episode_dir):
encoded_path = os.path.join(episode_dir, "encoded_video.pth")
if os.path.exists(encoded_path):
episode_dirs.append(encoded_path)
print(f"找到 {len(episode_dirs)} 个编码的episodes")
if max_episodes and len(episode_dirs) > max_episodes:
episode_dirs = episode_dirs[:max_episodes]
print(f"限制处理前 {max_episodes} 个episodes")
# 批量解码
success_count = 0
for i, encoded_pth_path in enumerate(tqdm(episode_dirs, desc="解码episodes")):
print(f"\n{'='*60}")
print(f"处理 {i+1}/{len(episode_dirs)}: {os.path.basename(os.path.dirname(encoded_pth_path))}")
success = decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device)
if success:
success_count += 1
print(f"当前成功率: {success_count}/{i+1} ({success_count/(i+1)*100:.1f}%)")
print(f"\n🎉 批量解码完成!")
print(f"总处理: {len(episode_dirs)} 个episodes")
print(f"成功解码: {success_count} 个episodes")
print(f"成功率: {success_count/len(episode_dirs)*100:.1f}%")
def main():
parser = argparse.ArgumentParser(description="解码Open-X编码的latents以验证正确性 - 修正版本")
parser.add_argument("--mode", type=str, choices=["single", "batch"], default="batch",
help="解码模式:single (单个episode) 或 batch (批量)")
parser.add_argument("--encoded_pth", type=str,
default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000000/encoded_video.pth",
help="单个编码文件路径(single模式)")
parser.add_argument("--encoded_base_dir", type=str,
default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
help="编码数据基础目录(batch模式)")
parser.add_argument("--vae_path", type=str,
default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
help="VAE模型路径")
parser.add_argument("--output_dir", type=str,
default="./decoded_results_fixed",
help="解码输出目录")
parser.add_argument("--max_episodes", type=int, default=5,
help="最大解码episodes数量(batch模式,用于测试)")
parser.add_argument("--device", type=str, default="cuda",
help="计算设备")
args = parser.parse_args()
print("🔧 Open-X Latents 解码验证工具 (修正版本 - Fixed)")
print(f"模式: {args.mode}")
print(f"VAE路径: {args.vae_path}")
print(f"输出目录: {args.output_dir}")
print(f"设备: {args.device}")
# 🔧 检查CUDA可用性
if args.device == "cuda" and not torch.cuda.is_available():
print("⚠️ CUDA不可用,切换到CPU")
args.device = "cpu"
# 确保输出目录存在
os.makedirs(args.output_dir, exist_ok=True)
if args.mode == "single":
print(f"输入文件: {args.encoded_pth}")
if not os.path.exists(args.encoded_pth):
print(f"❌ 输入文件不存在: {args.encoded_pth}")
return
success = decode_single_episode(args.encoded_pth, args.vae_path, args.output_dir, args.device)
if success:
print("✅ 单个episode解码成功")
else:
print("❌ 单个episode解码失败")
elif args.mode == "batch":
print(f"输入目录: {args.encoded_base_dir}")
print(f"最大episodes: {args.max_episodes}")
if not os.path.exists(args.encoded_base_dir):
print(f"❌ 输入目录不存在: {args.encoded_base_dir}")
return
batch_decode_episodes(args.encoded_base_dir, args.vae_path, args.output_dir, args.max_episodes, args.device)
if __name__ == "__main__":
main()