File size: 19,424 Bytes
08bf07d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
import os
import torch
import numpy as np
from PIL import Image
import imageio
import argparse
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
from tqdm import tqdm
import json

class VideoDecoder:
    """Decode video latents with the pipeline's VAE and write the frames to disk.

    The instance owns a ``WanVideoReCamMasterPipeline`` built from a single VAE
    checkpoint. ``decode_latents_to_video`` turns a latent tensor into a
    ``uint8`` ``[T, H, W, 3]`` array and an encoded video file;
    ``save_frames_as_images`` dumps frames as PNG files.
    """

    def __init__(self, vae_path, device="cuda"):
        """Load the VAE checkpoint and build the pipeline on ``device``.

        Args:
            vae_path: Checkpoint path passed to ``ModelManager.load_models``.
            device: Device the pipeline and its VAE are moved to.
        """
        self.device = device

        # The checkpoint is loaded through a CPU-side model manager; the
        # resulting pipeline is moved to the target device afterwards.
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([vae_path])

        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe = self.pipe.to(device)

        # Move the VAE (and its inner ``model`` attribute when it has one)
        # explicitly so every component ends up on the same device.
        self.pipe.vae = self.pipe.vae.to(device)
        if hasattr(self.pipe.vae, 'model'):
            self.pipe.vae.model = self.pipe.vae.model.to(device)

        print(f"✅ VAE解码器初始化完成,设备: {device}")

    def decode_latents_to_video(self, latents, output_path, tiled=True, tile_size=(34, 34),
                                tile_stride=(18, 16), fps=10, quality=8):
        """Decode ``latents`` and save the result as a video file.

        Args:
            latents: Latent tensor, either 4-D (a batch axis is prepended) or
                5-D. It is cast to the VAE's own device and dtype.
            output_path: Destination video file. Its parent directory is
                created when the path has one.
            tiled: Forwarded to ``pipe.decode_video``.
            tile_size: Forwarded to ``pipe.decode_video`` when ``tiled``.
            tile_stride: Forwarded to ``pipe.decode_video`` when ``tiled``.
            fps: Frame rate handed to ``imageio.get_writer``.
            quality: Quality setting handed to ``imageio.get_writer``.

        Returns:
            The decoded frames as a ``uint8`` numpy array shaped
            ``[T, H, W, 3]``.

        Raises:
            ValueError: If the decoded tensor is not a 4-D/5-D tensor with a
                3-channel axis in a supported position.
            Exception: Whatever the decoder or the video writer raises is
                re-raised after being logged.
        """
        print(f"🔧 开始解码latents...")
        print(f"输入latents形状: {latents.shape}")
        print(f"输入latents设备: {latents.device}")
        print(f"输入latents数据类型: {latents.dtype}")

        # Make sure there is a batch axis: [C, T, H, W] -> [1, C, T, H, W].
        if latents.ndim == 4:
            latents = latents.unsqueeze(0)

        # The VAE's first parameter defines the device/dtype to decode with.
        vae_param = next(self.pipe.vae.parameters())
        model_dtype = vae_param.dtype
        model_device = vae_param.device

        print(f"模型设备: {model_device}")
        print(f"模型数据类型: {model_dtype}")

        latents = latents.to(device=model_device, dtype=model_dtype)

        print(f"解码latents形状: {latents.shape}")
        print(f"解码latents设备: {latents.device}")
        print(f"解码latents数据类型: {latents.dtype}")

        # Keep the pipeline's notion of "device" in sync with the VAE weights.
        self.pipe.device = model_device

        decoded_video = self._decode(latents, model_device, tiled, tile_size, tile_stride)
        print(f"解码后视频形状: {decoded_video.shape}")

        video_np = self._to_thwc_array(decoded_video)
        print(f"转换后视频数组形状: {video_np.shape}")

        # De-normalise: x * 0.5 + 0.5 clipped to [0, 1], then scaled to 8 bit
        # (presumably the decoder emits values in [-1, 1] -- TODO confirm).
        video_np = (video_np * 0.5 + 0.5).clip(0, 1)
        video_np = (video_np * 255).astype(np.uint8)

        print(f"最终视频数组形状: {video_np.shape}")
        print(f"视频数组值范围: {video_np.min()} - {video_np.max()}")

        self._write_video(video_np, output_path, fps, quality)

        print(f"✅ 视频保存到: {output_path}")
        return video_np

    def _decode(self, latents, model_device, tiled, tile_size, tile_stride):
        """Run ``pipe.decode_video``; fall back to ``pipe.vae.decode`` on failure."""
        with torch.no_grad():
            try:
                if tiled:
                    print("🔧 尝试tiled解码...")
                    return self.pipe.decode_video(
                        latents,
                        tiled=True,
                        tile_size=tile_size,
                        tile_stride=tile_stride,
                    )
                print("🔧 使用非tiled解码...")
                return self.pipe.decode_video(latents, tiled=False)
            except Exception as e:
                print(f"decode_video失败,错误: {e}")
                import traceback
                traceback.print_exc()

                # Fallback: call the VAE directly. It stays inside the handler
                # so a second failure keeps the first one as its context.
                # NOTE(review): the batch axis is stripped before this call;
                # confirm that ``vae.decode`` expects an un-batched tensor.
                try:
                    print("🔧 尝试直接调用VAE解码...")
                    decoded_video = self.pipe.vae.decode(
                        latents.squeeze(0),
                        device=model_device,
                        tiled=False,
                    )
                    # Restore a leading batch axis when the VAE returned 4-D.
                    if len(decoded_video.shape) == 4:
                        decoded_video = decoded_video.unsqueeze(0)
                    return decoded_video
                except Exception as e2:
                    print(f"直接VAE解码也失败: {e2}")
                    raise

    @staticmethod
    def _to_thwc_array(decoded_video):
        """Convert a decoded 4-D/5-D tensor to a float32 ``[T, H, W, 3]`` array.

        The channel axis is located by its size (3). For 5-D input only the
        first batch element is kept.
        """
        shape = tuple(decoded_video.shape)

        if len(shape) == 5:
            if shape[0] > 1:
                print(f"⚠️ 批大小为 {shape[0]},仅保存第一个样本")
            if shape[1] == 3:
                # [B, C, T, H, W] -> [T, H, W, C]
                print("🔧 检测到格式: [B, C, T, H, W]")
                frames = decoded_video[0].permute(1, 2, 3, 0)
            elif shape[-1] == 3:
                # [B, T, H, W, C] is already channel-last.
                print("🔧 检测到格式: [B, T, H, W, C]")
                frames = decoded_video[0]
            elif 3 in shape:
                channel_dim = shape.index(3)
                print(f"⚠️ 未知的通道维度位置: {channel_dim}")
                raise ValueError(f"Cannot handle channel dimension at position {channel_dim}")
            else:
                print(f"⚠️ 未找到通道维度为3的位置,形状: {decoded_video.shape}")
                raise ValueError(f"Cannot find channel dimension of size 3 in shape {decoded_video.shape}")
        elif len(shape) == 4:
            if shape[-1] == 3:
                # [T, H, W, C]
                frames = decoded_video
            elif shape[0] == 3:
                # [C, T, H, W] -> [T, H, W, C]
                frames = decoded_video.permute(1, 2, 3, 0)
            else:
                print(f"⚠️ 无法处理的4D视频形状: {decoded_video.shape}")
                raise ValueError(f"Cannot handle 4D video tensor shape: {decoded_video.shape}")
        else:
            print(f"⚠️ 意外的视频维度数: {len(decoded_video.shape)}")
            raise ValueError(f"Unexpected video tensor dimensions: {decoded_video.shape}")

        # Cast to float32 before handing the data to numpy.
        return frames.to(torch.float32).cpu().numpy()

    def _write_video(self, video_np, output_path, fps, quality):
        """Encode ``video_np`` to ``output_path``; dump debug PNGs on failure."""
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        try:
            with imageio.get_writer(output_path, fps=fps, quality=quality) as writer:
                for frame_idx, frame in enumerate(video_np):
                    writer.append_data(frame)
                    if frame_idx % 10 == 0:
                        print(f"  写入帧 {frame_idx}/{len(video_np)}")
        except Exception as e:
            print(f"保存视频失败: {e}")
            # Best effort: keep a few frames around for inspection.
            self._dump_debug_frames(video_np, os.path.join(out_dir, "debug_frames"))
            raise

    @staticmethod
    def _dump_debug_frames(video_np, debug_dir, max_frames=5):
        """Save up to ``max_frames`` leading frames as PNG files; never raises."""
        try:
            os.makedirs(debug_dir, exist_ok=True)
        except OSError as e:
            print(f"调试: 无法创建目录 {debug_dir}: {e}")
            return

        for i, frame in enumerate(video_np[:max_frames]):
            debug_path = os.path.join(debug_dir, f"debug_frame_{i}.png")
            try:
                if len(frame.shape) == 3 and frame.shape[-1] == 3:
                    Image.fromarray(frame).save(debug_path)
                    print(f"调试: 保存帧 {i} 到 {debug_path}")
                else:
                    print(f"调试: 帧 {i} 形状异常: {frame.shape}")
            except Exception as e2:
                print(f"调试: 保存帧 {i} 失败: {e2}")

    def save_frames_as_images(self, video_np, output_dir, prefix="frame"):
        """Save each ``[H, W, 3]`` frame of ``video_np`` as a PNG file.

        Args:
            video_np: Iterable of frames; frames that are not 3-D with three
                channels in the last axis are skipped with a warning.
            output_dir: Directory to write into (created if missing).
            prefix: File name prefix; files are named ``{prefix}_{i:04d}.png``.
        """
        os.makedirs(output_dir, exist_ok=True)

        saved = 0
        for i, frame in enumerate(video_np):
            if len(frame.shape) != 3 or frame.shape[-1] != 3:
                print(f"⚠️ 跳过形状异常的帧 {i}: {frame.shape}")
                continue
            frame_path = os.path.join(output_dir, f"{prefix}_{i:04d}.png")
            Image.fromarray(frame).save(frame_path)
            saved += 1

        # Report what was actually written, not the number of input frames.
        print(f"✅ 保存了 {saved} 帧到: {output_dir}")

def decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device="cuda"):
    """Decode one encoded episode file into a video, sample frames and a JSON report.

    The file is expected to hold a dict with a ``'latents'`` tensor and an
    optional ``'episode_info'`` dict (``episode_idx``, ``total_frames``).
    Outputs are written to ``<output_base_dir>/episode_<idx>/``.

    Args:
        encoded_pth_path: Path of the ``.pth`` file to load with ``torch.load``.
        vae_path: VAE checkpoint path passed to ``VideoDecoder``.
        output_base_dir: Directory under which the episode folder is created.
        device: Device passed to ``VideoDecoder``.

    Returns:
        ``True`` when decoding and all writes succeeded, ``False`` otherwise
        (every failure is logged, none is raised).
    """
    print(f"\n🔧 解码episode: {encoded_pth_path}")

    # SECURITY: weights_only=False unpickles arbitrary objects. Only point this
    # tool at encoded files that come from a trusted source.
    try:
        encoded_data = torch.load(encoded_pth_path, weights_only=False, map_location="cpu")
        print(f"✅ 成功加载编码数据")
    except Exception as e:
        print(f"❌ 加载编码数据失败: {e}")
        return False

    # Anything but a dict would crash on .items()/.get() below.
    if not isinstance(encoded_data, dict):
        print(f"❌ 编码数据格式错误: 期望dict,实际{type(encoded_data)}")
        return False

    # Log the payload layout to help diagnose malformed files.
    print("🔍 编码数据结构:")
    for key, value in encoded_data.items():
        if isinstance(value, torch.Tensor):
            print(f"  - {key}: {value.shape}, dtype: {value.dtype}, device: {value.device}")
        elif isinstance(value, dict):
            print(f"  - {key}: dict with keys {list(value.keys())}")
        else:
            print(f"  - {key}: {type(value)}")

    latents = encoded_data.get('latents')
    if latents is None:
        print("❌ 未找到latents数据")
        return False
    if not isinstance(latents, torch.Tensor) or latents.ndim not in (4, 5):
        print(f"❌ latents格式错误: 期望4维或5维Tensor,实际{type(latents)}")
        return False

    # map_location="cpu" should already guarantee this; kept as a safety net.
    if latents.device != torch.device('cpu'):
        latents = latents.cpu()
        print(f"🔧 将latents移动到CPU: {latents.device}")

    # ``or {}`` also covers a stored ``episode_info`` of None.
    episode_info = encoded_data.get('episode_info') or {}
    episode_idx = episode_info.get('episode_idx', 'unknown')

    # Third axis from the end is the time axis for both layouts the decoder
    # accepts: [C, T, H, W] and [B, C, T, H, W].
    latent_frames = latents.shape[-3]
    # Without metadata, estimate the source frame count from the latent
    # length (assumes ~4x temporal compression -- TODO confirm for this VAE).
    total_frames = episode_info.get('total_frames', latent_frames * 4)

    print(f"Episode信息:")
    print(f"  - Episode索引: {episode_idx}")
    print(f"  - Latents形状: {latents.shape}")
    print(f"  - Latents设备: {latents.device}")
    print(f"  - Latents数据类型: {latents.dtype}")
    print(f"  - 原始总帧数: {total_frames}")
    print(f"  - 压缩后帧数: {latent_frames}")

    # Zero-pad integer indices; fall back to the raw value otherwise.
    episode_name = f"episode_{episode_idx:06d}" if isinstance(episode_idx, int) else f"episode_{episode_idx}"
    output_dir = os.path.join(output_base_dir, episode_name)
    os.makedirs(output_dir, exist_ok=True)

    try:
        decoder = VideoDecoder(vae_path, device)
    except Exception as e:
        print(f"❌ 初始化解码器失败: {e}")
        return False

    video_output_path = os.path.join(output_dir, "decoded_video.mp4")
    try:
        video_np = decoder.decode_latents_to_video(
            latents,
            video_output_path,
            tiled=False,  # non-tiled decoding is tried first to keep things simple
            tile_size=(34, 34),
            tile_stride=(18, 16)
        )

        # Keep the first (up to) 10 frames as images for a quick visual check.
        frames_dir = os.path.join(output_dir, "frames")
        decoder.save_frames_as_images(video_np[:10], frames_dir, f"frame_{episode_idx}")

        decoded_frames = len(video_np)
        decode_info = {
            "source_pth": encoded_pth_path,
            "decoded_video_path": video_output_path,
            "latents_shape": list(latents.shape),
            "decoded_video_shape": list(video_np.shape),
            "original_total_frames": total_frames,
            "decoded_frames": decoded_frames,
            "compression_ratio": total_frames / decoded_frames if decoded_frames > 0 else 0,
            "latents_dtype": str(latents.dtype),
            "latents_device": str(latents.device),
            "vae_compression_ratio": total_frames / latent_frames if latent_frames > 0 else 0
        }

        info_path = os.path.join(output_dir, "decode_info.json")
        with open(info_path, 'w', encoding="utf-8") as f:
            json.dump(decode_info, f, indent=2)

        print(f"✅ Episode {episode_idx} 解码完成")
        print(f"  - 原始帧数: {total_frames}")
        print(f"  - 解码帧数: {decoded_frames}")
        print(f"  - 压缩比: {decode_info['compression_ratio']:.2f}")
        print(f"  - VAE时间压缩比: {decode_info['vae_compression_ratio']:.2f}")
        return True

    except Exception as e:
        # Boundary handler: one bad episode must not abort a batch run.
        print(f"❌ 解码失败: {e}")
        import traceback
        traceback.print_exc()
        return False

def batch_decode_episodes(encoded_base_dir, vae_path, output_base_dir, max_episodes=None, device="cuda"):
    """Decode every ``<encoded_base_dir>/<episode>/encoded_video.pth`` file found.

    Episodes are processed in sorted directory order through
    ``decode_single_episode``; a running and a final success rate are printed.

    Args:
        encoded_base_dir: Directory whose sub-directories hold the encoded files.
        vae_path: VAE checkpoint path forwarded to ``decode_single_episode``.
        output_base_dir: Output directory forwarded to ``decode_single_episode``.
        max_episodes: When truthy, only the first ``max_episodes`` episodes
            are decoded; ``None`` (or 0) means no limit.
        device: Device forwarded to ``decode_single_episode``.
    """
    print(f"🔧 批量解码Open-X episodes")
    print(f"源目录: {encoded_base_dir}")
    print(f"输出目录: {output_base_dir}")

    # Collect the encoded file of every episode directory. Sorting keeps the
    # order (and therefore the effect of max_episodes) reproducible.
    episode_paths = []
    if os.path.isdir(encoded_base_dir):
        for item in sorted(os.listdir(encoded_base_dir)):
            episode_dir = os.path.join(encoded_base_dir, item)
            if not os.path.isdir(episode_dir):
                continue
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                episode_paths.append(encoded_path)

    print(f"找到 {len(episode_paths)} 个编码的episodes")

    # Nothing to do; returning here also avoids dividing by zero in the
    # summary below.
    if not episode_paths:
        print("⚠️ 未找到可解码的episodes")
        return

    if max_episodes and len(episode_paths) > max_episodes:
        episode_paths = episode_paths[:max_episodes]
        print(f"限制处理前 {max_episodes} 个episodes")

    total = len(episode_paths)
    success_count = 0
    for done, encoded_pth_path in enumerate(tqdm(episode_paths, desc="解码episodes"), start=1):
        print(f"\n{'='*60}")
        print(f"处理 {done}/{total}: {os.path.basename(os.path.dirname(encoded_pth_path))}")

        # NOTE: each call builds its own VideoDecoder, i.e. the VAE is
        # reloaded once per episode.
        if decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device):
            success_count += 1

        print(f"当前成功率: {success_count}/{done} ({success_count/done*100:.1f}%)")

    print(f"\n🎉 批量解码完成!")
    print(f"总处理: {total} 个episodes")
    print(f"成功解码: {success_count} 个episodes")
    print(f"成功率: {success_count/total*100:.1f}%")

def main():
    """Command-line entry point: decode one encoded episode or a directory of them.

    Parses the CLI options, falls back to CPU when a CUDA device is requested
    but unavailable, creates the output directory and dispatches to
    ``decode_single_episode`` (``--mode single``) or ``batch_decode_episodes``
    (``--mode batch``). Returns ``None``; a missing input path is reported
    and the function returns early.
    """
    parser = argparse.ArgumentParser(description="解码Open-X编码的latents以验证正确性 - 修正版本")
    parser.add_argument("--mode", type=str, choices=["single", "batch"], default="batch",
                       help="解码模式:single (单个episode) 或 batch (批量)")
    parser.add_argument("--encoded_pth", type=str,
                       default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000000/encoded_video.pth",
                       help="单个编码文件路径(single模式)")
    parser.add_argument("--encoded_base_dir", type=str,
                       default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
                       help="编码数据基础目录(batch模式)")
    parser.add_argument("--vae_path", type=str,
                       default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
                       help="VAE模型路径")
    parser.add_argument("--output_dir", type=str,
                       default="./decoded_results_fixed",
                       help="解码输出目录")
    parser.add_argument("--max_episodes", type=int, default=5,
                       help="最大解码episodes数量(batch模式,用于测试)")
    parser.add_argument("--device", type=str, default="cuda",
                       help="计算设备")

    args = parser.parse_args()

    print("🔧 Open-X Latents 解码验证工具 (修正版本 - Fixed)")
    print(f"模式: {args.mode}")
    print(f"VAE路径: {args.vae_path}")
    print(f"输出目录: {args.output_dir}")
    print(f"设备: {args.device}")

    # Fall back to CPU for any CUDA device spec ("cuda", "cuda:0", ...), not
    # only for the bare string "cuda".
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        print("⚠️ CUDA不可用,切换到CPU")
        args.device = "cpu"

    os.makedirs(args.output_dir, exist_ok=True)

    if args.mode == "single":
        print(f"输入文件: {args.encoded_pth}")
        if not os.path.exists(args.encoded_pth):
            print(f"❌ 输入文件不存在: {args.encoded_pth}")
            return

        success = decode_single_episode(args.encoded_pth, args.vae_path, args.output_dir, args.device)
        if success:
            print("✅ 单个episode解码成功")
        else:
            print("❌ 单个episode解码失败")

    elif args.mode == "batch":
        print(f"输入目录: {args.encoded_base_dir}")
        print(f"最大episodes: {args.max_episodes}")

        if not os.path.exists(args.encoded_base_dir):
            print(f"❌ 输入目录不存在: {args.encoded_base_dir}")
            return

        batch_decode_episodes(args.encoded_base_dir, args.vae_path, args.output_dir, args.max_episodes, args.device)

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()